spark_framework-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spark_framework/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from spark_framework.core.config import PipelineConfig
+ from spark_framework.core.pipeline import Pipeline, PipelineResult
+ from spark_framework.framework import SparkFramework
+
+ __version__ = "0.2.0"
+ __all__ = ["SparkFramework", "Pipeline", "PipelineResult", "PipelineConfig"]
spark_framework/cli.py ADDED
@@ -0,0 +1,52 @@
+ """CLI entry point for SparkFramework (spark-framework <config.json>)."""
+
+ import argparse
+ import sys
+
+ from spark_framework import SparkFramework
+
+
+ def build_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(
+         prog="spark-framework",
+         description="SparkFramework — Motor de pipelines Spark orientado a JSON",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Exemplos:
+   spark-framework examples/basic_parquet.json
+   spark-framework examples/iceberg_upsert.json --stop-spark
+   spark-framework tests/ingestion_csv_to_parquet.json
+ """,
+     )
+     parser.add_argument("config", help="Caminho para o arquivo JSON de configuracao")
+     parser.add_argument(
+         "--stop-spark",
+         action="store_true",
+         help="Encerra a SparkSession ao final",
+     )
+     return parser
+
+
+ def main() -> None:
+     parser = build_parser()
+     args = parser.parse_args()
+
+     fw = SparkFramework()
+     result = fw.run(args.config)
+
+     print(result.summary())
+
+     if result.validation_results:
+         print("\nValidacoes:")
+         for r in result.validation_results:
+             print(f" {r}")
+
+     if args.stop_spark:
+         fw.stop()
+
+     if not result.success:
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
spark_framework/core/config.py ADDED
@@ -0,0 +1,154 @@
+ from __future__ import annotations
+
+ import json
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+
+ @dataclass
+ class SparkConfig:
+     app_name: str = "SparkFramework"
+     master: str = "local[*]"
+     configs: Dict[str, str] = field(default_factory=dict)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> SparkConfig:
+         return cls(
+             app_name=data.get("app_name", "SparkFramework"),
+             master=data.get("master", "local[*]"),
+             configs=data.get("configs", {}),
+         )
+
+
+ @dataclass
+ class InputConfig:
+     """Configuration for the pipeline's main data source."""
+
+     format: str
+     path: str
+     options: Dict[str, Any] = field(default_factory=dict)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> InputConfig:
+         return cls(
+             format=data["format"].lower(),
+             path=data["path"],
+             options=data.get("options", {}),
+         )
+
+
+ _TRANSFORMATION_META_KEYS = {"type", "skip_if_false"}
+
+
+ @dataclass
+ class TransformationConfig:
+     type: str
+     params: Dict[str, Any] = field(default_factory=dict)
+     skip_if_false: Optional[str] = None
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> TransformationConfig:
+         return cls(
+             type=data["type"],
+             skip_if_false=data.get("skip_if_false"),
+             params={k: v for k, v in data.items() if k not in _TRANSFORMATION_META_KEYS},
+         )
+
+
+ @dataclass
+ class ValidationRule:
+     type: str
+     params: Dict[str, Any] = field(default_factory=dict)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> ValidationRule:
+         return cls(
+             type=data["type"],
+             params={k: v for k, v in data.items() if k != "type"},
+         )
+
+
+ @dataclass
+ class ValidationConfig:
+     on_failure: str = "fail"  # fail | warn | skip
+     rules: List[ValidationRule] = field(default_factory=list)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> ValidationConfig:
+         return cls(
+             on_failure=data.get("on_failure", "fail"),
+             rules=[ValidationRule.from_dict(r) for r in data.get("rules", [])],
+         )
+
+
+ @dataclass
+ class OutputConfig:
+     """Configuration for one of the pipeline's write targets.
+
+     The `columns` field selects which columns are written to this
+     specific target, without changing the DataFrame seen by the other
+     outputs. If omitted, all columns are written.
+     """
+
+     format: str
+     path: str
+     mode: str = "overwrite"  # append | overwrite | merge
+     partition_by: List[str] = field(default_factory=list)
+     columns: Optional[List[str]] = None
+     options: Dict[str, Any] = field(default_factory=dict)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> OutputConfig:
+         return cls(
+             format=data["format"].lower(),
+             path=data["path"],
+             mode=data.get("mode", "overwrite"),
+             partition_by=data.get("partition_by", []),
+             columns=data.get("columns"),
+             options=data.get("options", {}),
+         )
+
+
+ @dataclass
+ class PipelineConfig:
+     name: str
+     input: InputConfig
+     outputs: List[OutputConfig]
+     description: str = ""
+     spark: SparkConfig = field(default_factory=SparkConfig)
+     transformations: List[TransformationConfig] = field(default_factory=list)
+     validations: ValidationConfig = field(default_factory=ValidationConfig)
+
+     @classmethod
+     def from_file(cls, path: str) -> PipelineConfig:
+         from spark_framework.utils.includes import resolve_includes
+         content = Path(path).read_text(encoding="utf-8")
+         data = json.loads(content)
+         data = resolve_includes(data, Path(path).parent)
+         return cls.from_dict(data)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> PipelineConfig:
+         # Accepts "output" (single object) or "outputs" (list)
+         if "outputs" in data:
+             outputs = [OutputConfig.from_dict(o) for o in data["outputs"]]
+         elif "output" in data:
+             outputs = [OutputConfig.from_dict(data["output"])]
+         else:
+             raise ValueError(
+                 "O JSON do pipeline precisa ter 'output' (objeto) ou 'outputs' (lista)."
+             )
+
+         return cls(
+             name=data["name"],
+             description=data.get("description", ""),
+             spark=SparkConfig.from_dict(data.get("spark", {})),
+             input=InputConfig.from_dict(data["input"]),
+             transformations=[
+                 TransformationConfig.from_dict(t)
+                 for t in data.get("transformations", [])
+             ],
+             validations=ValidationConfig.from_dict(data.get("validations", {})),
+             outputs=outputs,
+         )
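For reference, a minimal sketch of a config dict that PipelineConfig.from_dict accepts, based on the dataclasses above. The transformation type "filter" and the validation rule "not_null" are illustrative: the names actually available depend on what TransformationEngine and ValidationEngine register, which is not shown in this diff.

from spark_framework.core.config import PipelineConfig

cfg = {
    "name": "clientes_parquet",
    "description": "CSV to Parquet example",
    "input": {"format": "csv", "path": "data/clientes.csv", "options": {"sep": ";"}},
    "transformations": [
        # every key except "type" and "skip_if_false" becomes a transformation param
        {"type": "filter", "condition": "ativo = true"},
    ],
    "validations": {
        "on_failure": "warn",
        "rules": [{"type": "not_null", "columns": ["id"]}],
    },
    "outputs": [
        {
            "format": "parquet",
            "path": "out/clientes",
            "mode": "overwrite",
            "partition_by": ["uf"],
            "columns": ["id", "nome", "uf"],  # only these columns are written to this target
        }
    ],
}

config = PipelineConfig.from_dict(cfg)
print(config.outputs[0].partition_by)  # ['uf']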
spark_framework/core/context.py ADDED
@@ -0,0 +1,74 @@
+ from __future__ import annotations
+
+ import os
+
+ from pyspark.sql import SparkSession
+
+ from spark_framework.core.config import SparkConfig
+
+
+ def _detect_environment() -> str:
+     """Detect the current execution environment.
+
+     Returns:
+         "databricks" | "emr" | "dataproc" | "synapse" | "local"
+     """
+     if "DATABRICKS_RUNTIME_VERSION" in os.environ:
+         return "databricks"
+     if os.environ.get("EMR_CLUSTER_ID") or os.path.exists("/mnt/var/lib/info/job-flow.json"):
+         return "emr"
+     if os.environ.get("DATAPROC_IMAGE_VERSION") or os.environ.get("DATAPROC_CLUSTER_NAME"):
+         return "dataproc"
+     if os.environ.get("SYNAPSE_WORKSPACE_NAME") or os.environ.get("AZURE_DATABRICKS_ORG_ID"):
+         return "synapse"
+     return "local"
+
+
+ class SparkContextManager:
+     """Manages a single SparkSession for the whole pipeline lifetime.
+
+     On Databricks it reuses the runtime's active session and never creates a new one.
+     In other environments (EMR, Dataproc, Synapse, local) it creates/reuses one via the builder.
+     """
+
+     _session: SparkSession | None = None
+
+     @classmethod
+     def get_or_create(cls, config: SparkConfig) -> SparkSession:
+         if cls._session is not None:
+             return cls._session
+
+         env = _detect_environment()
+
+         if env == "databricks":
+             # On Databricks the session already exists; creating a new one would raise an error.
+             cls._session = SparkSession.getActiveSession()
+             if cls._session is None:
+                 # Unlikely fallback, but safe
+                 cls._session = SparkSession.builder.getOrCreate()
+         else:
+             builder = SparkSession.builder.appName(config.app_name)
+
+             # Setting a master only makes sense outside managed clusters
+             if env == "local":
+                 builder = builder.master(config.master)
+
+             for key, value in config.configs.items():
+                 builder = builder.config(key, value)
+
+             cls._session = builder.getOrCreate()
+             cls._session.sparkContext.setLogLevel("WARN")
+
+         return cls._session
+
+     @classmethod
+     def stop(cls) -> None:
+         env = _detect_environment()
+         if cls._session is not None and env != "databricks":
+             cls._session.stop()
+         cls._session = None
+
+     @classmethod
+     def current_environment(cls) -> str:
+         """Return the detected environment (useful for logs and diagnostics)."""
+         return _detect_environment()
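A small sketch of how the detection above can be used from calling code; the timezone override is illustrative, and on Databricks extra configs are ignored anyway because the runtime session is reused as-is.

from spark_framework import SparkFramework
from spark_framework.core.context import SparkContextManager

# current_environment() only inspects environment variables and a marker file,
# so it is safe to call before any SparkSession exists
env = SparkContextManager.current_environment()

# only pass extra session configs where the framework actually builds the session
extra = {} if env == "databricks" else {"spark.sql.session.timeZone": "UTC"}
fw = SparkFramework(spark={"configs": extra})
print(f"running on: {env}")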
spark_framework/core/pipeline.py ADDED
@@ -0,0 +1,145 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional
+
+ from pyspark.sql import DataFrame, SparkSession
+ from pyspark.sql import functions as F
+
+ from spark_framework.core.config import OutputConfig, PipelineConfig
+ from spark_framework.core.context import SparkContextManager
+ from spark_framework.io.factory import ReaderFactory, WriterFactory
+ from spark_framework.transform.engine import TransformationEngine
+ from spark_framework.validation.base import ValidationResult
+ from spark_framework.validation.engine import ValidationEngine
+ from spark_framework.utils.logger import logger
+
+
+ @dataclass
+ class PipelineResult:
+     pipeline_name: str
+     success: bool
+     rows_read: int = 0
+     rows_written: int = 0
+     validation_results: List[ValidationResult] = field(default_factory=list)
+     error: Optional[str] = None
+     output_df: Optional[DataFrame] = None  # DataFrame after transformations; populated on success
+
+     def summary(self) -> str:
+         if not self.success:
+             return f"[FAIL] '{self.pipeline_name}': {self.error}"
+         passed = sum(1 for r in self.validation_results if r.passed)
+         total = len(self.validation_results)
+         return (
+             f"[OK] '{self.pipeline_name}' | "
+             f"lidos={self.rows_read} | "
+             f"escritos={self.rows_written} | "
+             f"validacoes={passed}/{total}"
+         )
+
+
+ class Pipeline:
+     """Orchestrates the flow: read → transform → validate → write."""
+
+     def __init__(
+         self,
+         config: PipelineConfig,
+         transform_engine: Optional[TransformationEngine] = None,
+         validation_engine: Optional[ValidationEngine] = None,
+         input_df: Optional[DataFrame] = None,
+         columns: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         self.config = config
+         self._transform_engine = transform_engine or TransformationEngine()
+         self._validation_engine = validation_engine or ValidationEngine()
+         self._input_df = input_df
+         self._columns: Dict[str, Any] = columns or {}
+
+     @classmethod
+     def from_file(cls, path: str) -> Pipeline:
+         return cls(PipelineConfig.from_file(path))
+
+     @classmethod
+     def from_dict(cls, data: dict) -> Pipeline:
+         return cls(PipelineConfig.from_dict(data))
+
+     def run(self) -> PipelineResult:
+         log = logger.bind(pipeline=self.config.name)
+         log.info("Pipeline iniciado")
+
+         try:
+             spark = SparkContextManager.get_or_create(self.config.spark)
+
+             if self._input_df is not None:
+                 df = self._input_df
+                 rows_read = 0
+                 log.info("Input df injetado externamente", colunas=len(df.columns))
+             else:
+                 df = ReaderFactory.create(spark, self.config.input).read()
+                 df = df.withColumn("ingestion_ts", F.current_timestamp())
+                 rows_read = df.count()
+                 log.info(
+                     "Leitura concluida",
+                     linhas=rows_read,
+                     formato=self.config.input.format,
+                 )
+
+             for col_name, value in self._columns.items():
+                 df = df.withColumn(col_name, F.lit(value))
+             if self._columns:
+                 log.info("Colunas injetadas", colunas=list(self._columns))
+
+             df = self._transform_engine.apply(df, self.config.transformations)
+             log.info("Transformacoes aplicadas")
+
+             validation_results = self._validation_engine.validate(
+                 df, self.config.validations
+             )
+
+             rows_written = df.count()
+             self._write_outputs(spark, df, log)
+             log.info("Pipeline concluido", linhas_escritas=rows_written)
+
+             return PipelineResult(
+                 pipeline_name=self.config.name,
+                 success=True,
+                 rows_read=rows_read,
+                 rows_written=rows_written,
+                 validation_results=validation_results,
+                 output_df=df,
+             )
+
+         except Exception as exc:
+             log.error("Pipeline falhou", error=str(exc))
+             return PipelineResult(
+                 pipeline_name=self.config.name,
+                 success=False,
+                 error=str(exc),
+             )
+
+     def _write_outputs(
+         self, spark: SparkSession, df: DataFrame, log
+     ) -> None:
+         for output in self.config.outputs:
+             output_df = self._project_columns(df, output)
+             log.info(
+                 "Escrevendo output",
+                 formato=output.format,
+                 path=output.path,
+                 modo=output.mode,
+                 colunas=output.columns or "todas",
+             )
+             WriterFactory.create(spark, output).write(output_df)
+
+     @staticmethod
+     def _project_columns(df: DataFrame, output: OutputConfig) -> DataFrame:
+         """Apply column selection when the output defines 'columns'."""
+         if not output.columns:
+             return df
+         missing = [c for c in output.columns if c not in df.columns]
+         if missing:
+             raise ValueError(
+                 f"Colunas inexistentes no output '{output.path}': {missing}. "
+                 f"Colunas disponiveis: {df.columns}"
+             )
+         return df.select(*output.columns)
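A sketch of driving Pipeline directly with an externally injected DataFrame, based on the constructor and run() above. The sample data and injected column are illustrative; examples/basic_parquet.json is the file referenced in the CLI epilog and its transformations may expect other columns.

from pyspark.sql import SparkSession

from spark_framework.core.config import PipelineConfig
from spark_framework.core.pipeline import Pipeline

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([(1, "Ana"), (2, "Bia")], ["id", "nome"])

config = PipelineConfig.from_file("examples/basic_parquet.json")
pipeline = Pipeline(
    config,
    input_df=df,                   # skips the reader (and the automatic ingestion_ts column)
    columns={"origem": "manual"},  # injected as literal columns before the transformations
)
result = pipeline.run()
print(result.summary())
if result.success and result.output_df is not None:
    result.output_df.show(5)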
spark_framework/framework.py ADDED
@@ -0,0 +1,164 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ from pyspark.sql import DataFrame
+
+ from spark_framework.core.config import PipelineConfig, SparkConfig
+ from spark_framework.utils.template import apply_template
+ from spark_framework.utils.includes import resolve_includes
+ from spark_framework.core.context import SparkContextManager
+ from spark_framework.core.pipeline import Pipeline, PipelineResult
+ from spark_framework.io.factory import ReaderFactory, WriterFactory
+ from spark_framework.io.base import BaseReader, BaseWriter
+ from spark_framework.transform.base import BaseTransformation
+ from spark_framework.transform.engine import TransformationEngine
+ from spark_framework.validation.base import BaseValidator
+ from spark_framework.validation.engine import ValidationEngine
+
+
+ class SparkFramework:
+     """Main entry point for using the framework as a library.
+
+     Manages a single SparkSession shared across multiple runs.
+     Allows registering custom formats, transformations, and validators that
+     remain available to every pipeline executed by the same instance.
+
+     Basic usage:
+         fw = SparkFramework()
+         fw.run("pipeline_clientes.json")
+         fw.run("pipeline_pedidos.json")
+         fw.stop()
+
+     With custom Spark settings:
+         fw = SparkFramework(spark={"master": "yarn", "app_name": "MeuJob"})
+
+     With custom extensions:
+         fw = SparkFramework()
+         fw.register_reader("delta", DeltaReader)
+         fw.register_transformation("normalize", NormalizeTransformation)
+         fw.run("config.json")
+     """
+
+     def __init__(self, spark: Optional[Dict[str, Any]] = None) -> None:
+         self._spark_config = SparkConfig.from_dict(spark or {})
+         self._transform_engine = TransformationEngine()
+         self._validation_engine = ValidationEngine()
+         SparkContextManager.get_or_create(self._spark_config)
+
+     # ------------------------------------------------------------------
+     # Pipeline execution
+     # ------------------------------------------------------------------
+
+     def run(
+         self,
+         config_path: str,
+         input_df: Optional[DataFrame] = None,
+         columns: Optional[Dict[str, Any]] = None,
+         params: Optional[Dict[str, Any]] = None,
+     ) -> PipelineResult:
+         """Run a pipeline from a JSON file.
+
+         Args:
+             config_path: path to the configuration JSON.
+             input_df: input DataFrame; when provided it replaces the 'input'
+                 declared in the JSON and ingestion_ts is not added automatically.
+             columns: literal columns to inject into the df before the transformations,
+                 e.g. {"param_tipo_ativo": "NC", "param_registradora": "CERC"}.
+             params: runtime values substituted as {key} in the JSON before
+                 parsing. Lists become SQL IN lists (e.g. 'a', 'b'); booleans become
+                 "true"/"" (empty is falsy and triggers skip_if_false).
+         """
+         config = self._load_config(config_path, params)
+         self._apply_spark_override(config)
+         return self._execute(config, input_df=input_df, columns=columns)
+
+     def run_from_dict(
+         self,
+         config: Dict[str, Any],
+         input_df: Optional[DataFrame] = None,
+         columns: Optional[Dict[str, Any]] = None,
+         params: Optional[Dict[str, Any]] = None,
+     ) -> PipelineResult:
+         """Run a pipeline from a Python dictionary."""
+         pipeline_config = self._load_config_from_dict(config, params)
+         self._apply_spark_override(pipeline_config)
+         return self._execute(pipeline_config, input_df=input_df, columns=columns)
+
+     # ------------------------------------------------------------------
+     # Extension registration
+     # ------------------------------------------------------------------
+
+     def register_reader(self, format_name: str, reader_cls: type[BaseReader]) -> None:
+         """Register a custom reader for a new format."""
+         ReaderFactory.register(format_name, reader_cls)
+
+     def register_writer(self, format_name: str, writer_cls: type[BaseWriter]) -> None:
+         """Register a custom writer for a new format."""
+         WriterFactory.register(format_name, writer_cls)
+
+     def register_transformation(
+         self, name: str, transformation_cls: type[BaseTransformation]
+     ) -> None:
+         """Register a custom transformation available from JSON configs."""
+         self._transform_engine.register(name, transformation_cls)
+
+     def register_validator(
+         self, name: str, validator_cls: type[BaseValidator]
+     ) -> None:
+         """Register a custom validator available from JSON configs."""
+         self._validation_engine.register(name, validator_cls)
+
+     # ------------------------------------------------------------------
+     # Lifecycle
+     # ------------------------------------------------------------------
+
+     def stop(self) -> None:
+         """Stop the SparkSession."""
+         SparkContextManager.stop()
+
+     # ------------------------------------------------------------------
+     # Internal
+     # ------------------------------------------------------------------
+
+     def _execute(
+         self,
+         config: PipelineConfig,
+         input_df: Optional[DataFrame] = None,
+         columns: Optional[Dict[str, Any]] = None,
+     ) -> PipelineResult:
+         pipeline = Pipeline(
+             config,
+             transform_engine=self._transform_engine,
+             validation_engine=self._validation_engine,
+             input_df=input_df,
+             columns=columns,
+         )
+         return pipeline.run()
+
+     def _load_config(self, path: str, params: Optional[Dict[str, Any]]) -> PipelineConfig:
+         raw = Path(path).read_text(encoding="utf-8")
+         if params:
+             raw = apply_template(raw, params)
+         data = json.loads(raw)
+         data = resolve_includes(data, Path(path).parent, params)
+         return PipelineConfig.from_dict(data)
+
+     def _load_config_from_dict(
+         self, config: Dict[str, Any], params: Optional[Dict[str, Any]]
+     ) -> PipelineConfig:
+         if params:
+             raw = apply_template(json.dumps(config), params)
+             config = json.loads(raw)
+         config = resolve_includes(config, Path.cwd(), params)
+         return PipelineConfig.from_dict(config)
+
+     def _apply_spark_override(self, config: PipelineConfig) -> None:
+         """Ensure the framework's Spark configs take precedence over the JSON."""
+         if self._spark_config.app_name != "SparkFramework":
+             config.spark.app_name = self._spark_config.app_name
+         if self._spark_config.master != "local[*]":
+             config.spark.master = self._spark_config.master
+         config.spark.configs.update(self._spark_config.configs)
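A sketch of the params/columns flow documented in run() above; the file name, parameter keys, and Spark settings are illustrative (the column example reuses the one from the docstring).

from spark_framework import SparkFramework

fw = SparkFramework(spark={"app_name": "MeuJob", "configs": {"spark.sql.shuffle.partitions": "64"}})

result = fw.run(
    "pipeline_clientes.json",
    # substituted as {data_ref} and {ufs} in the JSON before parsing;
    # the list is rendered as a SQL IN list ('SP', 'RJ')
    params={"data_ref": "2024-01-31", "ufs": ["SP", "RJ"]},
    # added to the DataFrame as literal columns before the transformations
    columns={"param_registradora": "CERC"},
)
print(result.summary())

if not result.success:
    raise RuntimeError(result.error)

fw.stop()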
spark_framework/io/base.py ADDED
@@ -0,0 +1,25 @@
+ from abc import ABC, abstractmethod
+
+ from pyspark.sql import DataFrame, SparkSession
+
+ from spark_framework.core.config import InputConfig, OutputConfig
+
+
+ class BaseReader(ABC):
+     def __init__(self, spark: SparkSession, config: InputConfig) -> None:
+         self.spark = spark
+         self.config = config
+
+     @abstractmethod
+     def read(self) -> DataFrame:
+         ...
+
+
+ class BaseWriter(ABC):
+     def __init__(self, spark: SparkSession, config: OutputConfig) -> None:
+         self.spark = spark
+         self.config = config
+
+     @abstractmethod
+     def write(self, df: DataFrame) -> None:
+         ...
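A sketch of a custom reader built on the BaseReader contract above and registered through SparkFramework.register_reader; the "jsonl" format name and the reader logic are illustrative, not part of the package.

from pyspark.sql import DataFrame

from spark_framework import SparkFramework
from spark_framework.io.base import BaseReader


class JsonLinesReader(BaseReader):
    def read(self) -> DataFrame:
        # self.spark and self.config (an InputConfig) are set by BaseReader.__init__
        return self.spark.read.options(**self.config.options).json(self.config.path)


fw = SparkFramework()
fw.register_reader("jsonl", JsonLinesReader)
# a pipeline JSON can now declare: "input": {"format": "jsonl", "path": "data/events.jsonl"}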
@@ -0,0 +1,29 @@
+ from pyspark.sql import DataFrame
+
+ from spark_framework.io.base import BaseReader, BaseWriter
+
+ _CSV_READ_DEFAULTS = {
+     "header": "true",
+     "inferSchema": "true",
+     "encoding": "UTF-8",
+ }
+
+ _CSV_WRITE_DEFAULTS = {
+     "header": "true",
+     "encoding": "UTF-8",
+ }
+
+
+ class CsvReader(BaseReader):
+     def read(self) -> DataFrame:
+         options = {**_CSV_READ_DEFAULTS, **self.config.options}
+         return self.spark.read.options(**options).csv(self.config.path)
+
+
+ class CsvWriter(BaseWriter):
+     def write(self, df: DataFrame) -> None:
+         options = {**_CSV_WRITE_DEFAULTS, **self.config.options}
+         writer = df.write.mode(self.config.mode).options(**options)
+         if self.config.partition_by:
+             writer = writer.partitionBy(*self.config.partition_by)
+         writer.csv(self.config.path)