psr-factory 5.0.0b21__py3-none-win_amd64.whl → 5.0.0b69__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
psr/runner/runner.py CHANGED
@@ -10,7 +10,7 @@ import shutil
 import socket
 import subprocess
 from types import ModuleType
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 import warnings


@@ -90,18 +90,10 @@ def _get_nproc(specified: int, available: int) -> int:
     else:
         return available

-def _write_mpi_settings(mpi_file_path: Union[str, pathlib.Path], cluster_settings: Optional[Union[int, bool, Dict[str, int]]]):
+def _write_mpi_settings(mpi_file_path: Union[str, pathlib.Path, Any], cluster_settings: Optional[Union[int, bool, Dict[str, int]]]):
     if cluster_settings is not None:
-        mpi_file_path = str(mpi_file_path)
         available_cpu = _get_available_cpu()
-        if isinstance(cluster_settings, int):
-            computer_name = _get_host_name()
-            specified_cpu_number = cluster_settings
-            nproc = _get_nproc(specified_cpu_number, available_cpu)
-            cluster_settings = {computer_name: nproc}
-        elif isinstance(cluster_settings, dict):
-            pass
-        elif isinstance(cluster_settings, bool):
+        if isinstance(cluster_settings, bool):
             # Rewrite with default settings.
             if cluster_settings:
                 computer_name = _get_host_name()
@@ -109,6 +101,13 @@ def _write_mpi_settings(mpi_file_path: Union[str, pathlib.Path], cluster_setting
                 cluster_settings = {computer_name: nproc}
             else:
                 cluster_settings = None
+        elif isinstance(cluster_settings, int):
+            computer_name = _get_host_name()
+            specified_cpu_number = cluster_settings
+            nproc = _get_nproc(specified_cpu_number, available_cpu)
+            cluster_settings = {computer_name: nproc}
+        elif isinstance(cluster_settings, dict):
+            pass
         else:
             raise ValueError("Invalid cluster settings type")
     else:
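
The reordering of the isinstance branches matters: bool is a subclass of int in Python, so under the old order (int tested first) a True/False setting fell into the int branch and was treated as a CPU count. Checking bool first routes boolean settings correctly:

    isinstance(True, int)    # True -- a bool also passes the int check
    isinstance(True, bool)   # True -- hence bool must be tested before int
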
@@ -117,22 +116,29 @@ def _write_mpi_settings(mpi_file_path: Union[str, pathlib.Path], cluster_setting
         cluster_settings = {computer_name: nproc}

     if isinstance(cluster_settings, dict):
-        with open(mpi_file_path, 'w') as f:
-            for computer, nproc in cluster_settings.items():
-                f.write(f"{computer}:{nproc}\n")
-
+        if isinstance(mpi_file_path, (str, pathlib.Path)):
+            f = open(mpi_file_path, 'w')
+            must_close = True
+        else:
+            f = open(mpi_file_path.name, 'w')
+            must_close = False
+        for computer, nproc in cluster_settings.items():
+            f.write(f"{computer}:{nproc}\n")
+        if must_close:
+            f.close()


 def run_sddp(case_path: Union[str, pathlib.Path], sddp_path: Union[str, pathlib.Path], **kwargs):
     case_path = os.path.abspath(str(case_path))
     sddp_path = str(sddp_path)
     parallel_run = kwargs.get("parallel_run", True)
-    cluster_settings: Optional[Union[int, bool, Dict[str, int]]] = kwargs.get("cluster_settings", False)
+    cluster_settings: Optional[Union[int, bool, Dict[str, int]]] = kwargs.get("cluster_settings", None)
     dry_run = kwargs.get("dry_run", False)
     show_progress = kwargs.get("show_progress", False)
     extra_args = " ".join(kwargs.get("extra_args", ()))
     exec_mode = kwargs.get("_mode", None)
     mpi_path = kwargs.get("mpi_path", __default_mpi_path)
+    env = kwargs.get("env", {})

     sddp_path_full = _get_sddp_executable_parent_path(sddp_path)
     # Append last / if missing.
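
For reference, the hosts file written by _write_mpi_settings is a plain list of computer:nproc pairs, one per line; when given an object that is neither a str nor a pathlib.Path (such as the temp-file wrapper used below), it reopens the object's .name path instead. A minimal sketch with hypothetical node names:

    _write_mpi_settings("mpd_sddp.hosts", {"node01": 8, "node02": 4})
    # mpd_sddp.hosts now contains:
    #   node01:8
    #   node02:4
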
@@ -144,14 +150,26 @@ def run_sddp(case_path: Union[str, pathlib.Path], sddp_path: Union[str, pathlib.

     major, minor, patch, tag = _get_semver_version(get_sddp_version(sddp_path))

-    with change_cwd(sddp_path_full):
+    temp_folder = os.path.join(os.getenv("TEMP") or os.getenv("TMPDIR") or os.getenv("TMP") or "/tmp", "")
+    with (psr.psrfcommon.tempfile.CreateTempFile(temp_folder, "mpd_sddp", "", ".hosts", False) as mpi_temp_file,
+          change_cwd(sddp_path_full)):
+
         # Write MPI settings if required
         if parallel_run and cluster_settings is not None:
-            if major >= 18:
+            if major >= 18 and minor >= 0 and patch >= 7:
+                _write_mpi_settings(mpi_temp_file, cluster_settings)
+                extra_args = extra_args + f" --hostsfile=\"{mpi_temp_file.name}\""
+                if dry_run:
+                    print("Using temporary mpi settings file:", mpi_temp_file.name)
+                mpi_written = True
+            elif major >= 18:
                 mpi_file_path = os.path.join(sddp_path_full, "mpd_sddp.hosts")
+                mpi_written = False
             else:
                 mpi_file_path = os.path.join(sddp_path_full, "mpd.hosts")
-            _write_mpi_settings(mpi_file_path, cluster_settings)
+                mpi_written = False
+            if not mpi_written:
+                _write_mpi_settings(mpi_file_path, cluster_settings)

         if parallel_run:
             if os.name == 'nt':
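
The gate that selects the temporary hosts file compares each version field independently, so all three bounds must hold at once. A minimal sketch of the predicate as written (check is a hypothetical name):

    def check(major: int, minor: int, patch: int) -> bool:
        return major >= 18 and minor >= 0 and patch >= 7

    check(18, 0, 7)   # True  -> temporary hosts file plus --hostsfile flag
    check(17, 9, 9)   # False -> legacy mpd.hosts next to the executable
    check(19, 1, 0)   # False -> patch >= 7 fails even on a newer release
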
@@ -173,7 +191,9 @@ def run_sddp(case_path: Union[str, pathlib.Path], sddp_path: Union[str, pathlib.
                 cmd = f'./sddp {mode_arg} -path "{case_path_last_slash}" {extra_args}'

         if os.name != "nt":
-            os.environ["LD_LIBRARY_PATH"] = os.path.abspath(sddp_path_full)
+            env["LD_LIBRARY_PATH"] = os.path.realpath(sddp_path_full)
+            env["MPI_PATH"] = os.path.realpath(mpi_path)
+        kwargs["env"] = env
         exec_cmd(cmd, **kwargs)

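run_sddp now accepts an optional env mapping through its keyword arguments and forwards it to exec_cmd, adding LD_LIBRARY_PATH and MPI_PATH on non-Windows hosts instead of mutating os.environ. A minimal sketch of a call, with hypothetical paths and an illustrative extra variable:

    run_sddp(
        "/cases/case01",               # hypothetical case directory
        "/opt/sddp",                   # hypothetical SDDP install
        cluster_settings=8,            # spread the run over up to 8 local processes
        env={"OMP_NUM_THREADS": "4"},  # illustrative extra environment variable
    )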
 
@@ -199,6 +219,9 @@ def _get_sddp_executable_parent_path(sddp_path: Union[str, pathlib.Path]) -> str
             return model_path
         else:
             return os.path.join(sddp_path, "Oper")
+    else:
+        # solve symlinks, if needed
+        sddp_path = os.path.realpath(sddp_path)
     return sddp_path

 def _get_optgen_executable_parent_path(optgen_path: Union[str, pathlib.Path]) -> str:
@@ -208,8 +231,23 @@ def _get_optgen_executable_parent_path(optgen_path: Union[str, pathlib.Path]) ->
             return model_path
         else:
             return os.path.join(optgen_path, "Model")
+    else:
+        # solve symlinks, if needed
+        optgen_path = os.path.realpath(optgen_path)
     return optgen_path

+def _get_optmain_executable_parent_path(optmain_path: Union[str, pathlib.Path]) -> str:
+    if os.name == 'nt':
+        model_path = os.path.join(optmain_path, "models", "optmain")
+        if os.path.exists(model_path):
+            return model_path
+        else:
+            return os.path.join(optmain_path, "Model")
+    else:
+        # solve symlinks, if needed
+        optmain_path = os.path.realpath(optmain_path)
+    return optmain_path
+

 def get_sddp_version(sddp_path: Union[str, pathlib.Path]) -> str:
     sddp_path = str(sddp_path)
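
On non-Windows systems the path helpers now resolve symlinks with os.path.realpath before returning, so a symlinked install directory maps to its physical location (os.path.abspath, used previously, leaves symlinks unresolved). A minimal sketch with a hypothetical layout:

    import os

    # Hypothetical: /opt/sddp is a symlink to /opt/releases/sddp-18.0.7
    os.path.realpath("/opt/sddp")  # '/opt/releases/sddp-18.0.7'
    os.path.abspath("/opt/sddp")   # '/opt/sddp' (symlink not followed)
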
@@ -220,8 +258,13 @@ def get_sddp_version(sddp_path: Union[str, pathlib.Path]) -> str:
     command = [os.path.join(sddp_path_full, "sddp"), "ver"]

     if os.name != "nt":
-        os.environ["LD_LIBRARY_PATH"] = os.path.abspath(sddp_path_full)
-    sub = subprocess.run(command, stdout=subprocess.PIPE, check=False)
+        env = {
+            "LD_LIBRARY_PATH": os.path.realpath(sddp_path_full)
+        }
+    else:
+        env = {}
+
+    sub = subprocess.run(command, stdout=subprocess.PIPE, check=False, env=env)
     output = sub.stdout.decode("utf-8").strip()
     return output.split()[2]
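
Note that subprocess.run(..., env=env) replaces the child's entire environment rather than extending it. Callers that need the usual inherited variables plus an override would merge explicitly; a minimal sketch of that standard pattern, not taken from the package:

    import os
    import subprocess

    env = dict(os.environ, LD_LIBRARY_PATH="/opt/sddp")  # hypothetical path
    subprocess.run(["./sddp", "ver"], stdout=subprocess.PIPE, check=False, env=env)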
 
@@ -289,6 +332,17 @@ def run_optgen_cleanup(case_path: Union[str, pathlib.Path], optgen_path: Union[s
     kwargs["_mode"] = "clean"
     run_optgen(case_path, optgen_path, sddp_path, **kwargs)

+
+def run_optmain(case_path: Union[str, pathlib.Path], optmain_path: Union[str, pathlib.Path], **kwargs):
+    case_path = os.path.abspath(str(case_path)).replace("\\", "/") + "/"
+    optmain_path = str(optmain_path)
+    optmain_path_full = _get_optmain_executable_parent_path(optmain_path)
+
+    with change_cwd(optmain_path_full):
+        cmd = f'optmain {case_path}'
+        exec_cmd(cmd, **kwargs)
+
+
 def run_psrio(case_path, sddp_path: str, **kwargs):
     recipe_script = kwargs.get('r', kwargs.get('recipes', False))
     output_path = kwargs.get('o', kwargs.get('output', False))
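
run_optmain follows the same pattern as the other runners: normalize the case path to forward slashes with a trailing slash, change into the executable's directory, and delegate to exec_cmd. A minimal usage sketch with hypothetical paths, assuming the shared exec_cmd options (such as dry_run) behave as in the other run_* helpers:

    run_optmain("/cases/expansion01", "/opt/optmain", dry_run=True)
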
@@ -595,6 +649,20 @@ def run_tslconsole(tsl_path: Union[str, pathlib.Path], script_path: Union[str, p
         cmd = f'TimeSeriesConsole.exe "{str(script_path)}"'
         exec_cmd(cmd, **kwargs)

+def run_tsl_generate_inflow_from_external_natural(case_path: Union[str, pathlib.Path], tsl_path: Union[str, pathlib.Path], **kwargs):
+    commands = ["generate_inflow_from_external_natural"]
+    case_path = os.path.abspath(str(case_path))
+    tsl_path = str(tsl_path)
+    _run_tslconsole_command(tsl_path, case_path, commands)
+
+
+def run_tsl_generate_inflow_from_external_incremental(case_path: Union[str, pathlib.Path], tsl_path: Union[str, pathlib.Path], **kwargs):
+    commands = ["generate_inflow_from_external_incremental"]
+    case_path = os.path.abspath(str(case_path))
+    tsl_path = str(tsl_path)
+    _run_tslconsole_command(tsl_path, case_path, commands)
+
+
 def run_tsl(case_path: Union[str, pathlib.Path], tsl_path: Union[str, pathlib.Path], base_type: str, **kwargs):
     if os.name != 'nt':
         raise NotImplementedError("Running TimeSeriesLab is only available on Windows")
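
Both new TSL wrappers queue a single console command through the internal _run_tslconsole_command helper; the **kwargs they accept are not forwarded. A minimal usage sketch with hypothetical Windows paths (TimeSeriesLab is Windows-only, as the check in run_tsl shows):

    run_tsl_generate_inflow_from_external_natural(
        r"C:\cases\inflow_study", r"C:\PSR\TimeSeriesLab"
    )
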
psr_factory-5.0.0b69.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: psr-factory
-Version: 5.0.0b21
+Version: 5.0.0b69
 Summary: PSR database management module.
 Author-email: "PSR Inc." <psrfactory@psr-inc.com>
 License-Expression: MIT
@@ -15,6 +15,7 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
+Classifier: Programming Language :: Python :: 3.14
 Classifier: Topic :: Software Development
 Classifier: Topic :: Scientific/Engineering
 Classifier: Operating System :: Microsoft :: Windows
@@ -28,11 +29,7 @@ Requires-Dist: pandas; extra == "pandas"
 Provides-Extra: polars
 Requires-Dist: polars; extra == "polars"
 Provides-Extra: cloud
-Requires-Dist: zeep; extra == "cloud"
-Requires-Dist: filelock; extra == "cloud"
-Requires-Dist: pefile; extra == "cloud"
-Requires-Dist: boto3; extra == "cloud"
-Requires-Dist: botocore; extra == "cloud"
+Requires-Dist: psr-cloud; extra == "cloud"
 Provides-Extra: execqueue-client
 Requires-Dist: requests; extra == "execqueue-client"
 Provides-Extra: execqueue-server
@@ -41,17 +38,10 @@ Requires-Dist: Flask; extra == "execqueue-server"
 Requires-Dist: python-ulid; extra == "execqueue-server"
 Requires-Dist: sqlalchemy; extra == "execqueue-server"
 Requires-Dist: python-dotenv; extra == "execqueue-server"
-Requires-Dist: pefile; extra == "execqueue-server"
-Requires-Dist: zeep; extra == "execqueue-server"
-Requires-Dist: filelock; extra == "execqueue-server"
-Requires-Dist: requests; extra == "execqueue-server"
+Requires-Dist: psr-cloud; extra == "execqueue-server"
 Provides-Extra: all
 Requires-Dist: pandas; extra == "all"
 Requires-Dist: polars; extra == "all"
 Requires-Dist: psutil; extra == "all"
-Requires-Dist: zeep; extra == "all"
-Requires-Dist: filelock; extra == "all"
-Requires-Dist: pefile; extra == "all"
-Requires-Dist: boto3; extra == "all"
-Requires-Dist: botocore; extra == "all"
+Requires-Dist: psr-cloud; extra == "all"
 Dynamic: license-file
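
The net effect of the metadata changes: the cloud, execqueue-server, and all extras now depend on the separate psr-cloud package instead of pinning its dependencies (zeep, filelock, pefile, boto3, botocore) directly. Installing an extra is unchanged from the user's side:

    pip install "psr-factory[cloud]"   # now resolves psr-cloud rather than boto3 et al.
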
psr_factory-5.0.0b69.dist-info/RECORD ADDED
@@ -0,0 +1,33 @@
+psr/apps/__init__.py,sha256=frSq1WIy5vIdU21xJIGX7U3XoAZRj0pcQmFb-R00b7I,228
+psr/apps/apps.py,sha256=V8Ewht7P1I-3sSkV3dnbxbLjF2slxPjcmtzmVaLjiNY,6746
+psr/apps/version.py,sha256=vs459L6JsatAkUxna7BNG-vMCaXpO1Ye8c1bmkEx4U4,194
+psr/execqueue/client.py,sha256=cQ2VK-jhalIiN4V6989jWpICnRhqaoFTl50SZNl92hw,5540
+psr/execqueue/config.py,sha256=F8sp-JGeoRspQRR63SjSKV5wDz0OVGnA-cNm1UDYHBY,1693
+psr/execqueue/db.py,sha256=T3EWiK_2USmgNKTiVaavbqnS-EklCCZKsScOtdD6FgM,10853
+psr/execqueue/server.py,sha256=0N5_ekXj1h5A1QyI4xPmTDpYlAo-qQfLWAa-TL0syWI,28781
+psr/execqueue/watcher.py,sha256=R1dyXJ-OYn_QjqdItBwbLJZQ2LcbtdHqnRaYkyphi4w,5637
+psr/factory/__init__.py,sha256=SDxMzOm1aV-IMUXaahBqoutgc1lBZ69rE4YfLaROy1w,212
+psr/factory/api.py,sha256=PK0OpX4AtHE2u25vmaornVR4XU5jbwiQiTGlAlVPvao,115062
+psr/factory/factory.dll,sha256=QHD1k65DWU4jzQSOjw4DQVgR4IxF-t_Q6IAr98TMOOo,20277584
+psr/factory/factory.pmd,sha256=wPTJucfpUxc7vyeldE-caD4TqJiT_dCbl-DY2zOkq8c,252791
+psr/factory/factory.pmk,sha256=WrSXrK3zR_7v4cOVppRGWehnLhnZFSFWHnjyA9OUe00,607612
+psr/factory/factorylib.py,sha256=-i_3j_dkCKZA_hpCePD7_qdGOqD8o89mwzqKmvya_EI,31226
+psr/factory/libcurl-x64.dll,sha256=6WGBmqX4q_eD8Vc0E2VpCvVrFV3W7TQoaKqSdbhXBu0,5313096
+psr/factory/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+psr/factory/samples/__init__.py,sha256=xxOch5Fokzjy571a6OHD87FWM17qKgvfcbr8xn-n36I,80
+psr/factory/samples/sddp_case01.py,sha256=h0iVqswPeOzauikiHjwUiPRQbmIQjECPcGkX54G1Clo,5060
+psr/factory/samples/sddp_case21.py,sha256=5DNIEu5aMhf8QOOJT-erdBmtozZukJA3k45C0nbfxYo,8791
+psr/outputs/__init__.py,sha256=Rs2MCBm2fUjwm9CAjZU-2Ll7KUsVEbW16mNY1Xnx-pI,193
+psr/outputs/outputs.py,sha256=gzmKl9Ma6K9gjpTTBc45p-IT7DgEX-PzRqtsQSWfja0,6727
+psr/outputs/resample.py,sha256=6fVBWOzxu3utkavGS5EIUd0QpSZm6kc-J1Qv8Bk0X6o,14276
+psr/psrfcommon/__init__.py,sha256=WXR560XQllIjtFpWd0jiJEbUAQIyh5-6lwj-42_J95c,200
+psr/psrfcommon/psrfcommon.py,sha256=LOuojeKX51eCMcPwpfDgRnGlK6WpS5UwDxQajbdRb5I,1571
+psr/psrfcommon/tempfile.py,sha256=5S13wa2DCLYTUdwbLm_KMBRnDRJ0WDlu8GO2BmZoNdg,3939
+psr/runner/__init__.py,sha256=kI9HDX-B_LMQJUHHylFHas2rNpWfNNa0pZXoIvX_Alw,230
+psr/runner/runner.py,sha256=bgxZAvlgroe_F9QCInp9XbTVfSMSwjbpLEei5FeAxao,30352
+psr/runner/version.py,sha256=mch2Y8anSXGMn9w72Z78PhSRhOyn55EwaoLAYhY4McE,194
+psr_factory-5.0.0b69.dist-info/licenses/LICENSE.txt,sha256=N6mqZK2Ft3iXGHj-by_MHC_dJo9qwn0URjakEPys3H4,1089
+psr_factory-5.0.0b69.dist-info/METADATA,sha256=-1mvIhGhrAzHEVozKeJcccVfT1dvU7dCf5koKzWIvBA,1914
+psr_factory-5.0.0b69.dist-info/WHEEL,sha256=ZjXRCNaQ9YSypEK2TE0LRB0sy2OVXSszb4Sx1XjM99k,97
+psr_factory-5.0.0b69.dist-info/top_level.txt,sha256=Jb393O96WQk3b5D1gMcrZBLKJJgZpzNjTPoldUi00ck,4
+psr_factory-5.0.0b69.dist-info/RECORD,,
psr/cloud/__init__.py DELETED
@@ -1,7 +0,0 @@
-# PSR Cloud. Copyright (C) PSR, Inc - All Rights Reserved
-# Unauthorized copying of this file, via any medium is strictly prohibited
-# Proprietary and confidential
-
-from .cloud import *
-from .data import *
-from .version import __version__
psr/cloud/aws.py DELETED
@@ -1,256 +0,0 @@
-import os
-import tempfile
-import zipfile
-from typing import Dict, List, Optional
-
-import boto3
-from botocore.exceptions import ClientError
-
-
-def _get_region(url):
-    """Extract the region from the S3 URL."""
-    if url:
-        parts = url.split(".")
-        return parts[0]
-    return None
-
-
-def upload_file_to_s3(
-    s3_client, bucket_name, file_path, object_name=None, extra_args=None
-):
-    """Upload a file to an S3 bucket using a provided S3 client.
-
-    :param s3_client: Initialized S3 client.
-    :param bucket_name: Name of the S3 bucket.
-    :param file_path: Path to the file to upload.
-    :param object_name: S3 object name. If not specified, file_path's basename is used.
-    :param extra_args: A dictionary of extra arguments to pass to S3's upload_file.
-    :return: True if file was uploaded, else False.
-    """
-    if object_name is None:
-        object_name = os.path.basename(file_path)
-
-    try:
-        s3_client.upload_file(file_path, bucket_name, object_name, ExtraArgs=extra_args)
-        return True
-    except ClientError as e:
-        print(f"Error uploading file: {e}")
-        return False
-
-
-def upload_case_to_s3(
-    files: List[str],
-    repository_id: str,
-    cluster_name: str,
-    checksums: Optional[Dict[str, str]] = None,
-    access: Optional[str] = None,
-    secret: Optional[str] = None,
-    session_token: Optional[str] = None,
-    bucket_name: Optional[str] = None,
-    url: Optional[str] = None,
-    zip_compress: bool = False,
-    compress_zip_name: str = None,
-):
-    """Upload files to an S3 bucket."""
-
-    region = _get_region(url)
-
-    if not region or not access or not secret or not session_token or not bucket_name:
-        raise ValueError("Unable to set up AWS connection.")
-
-    s3_client = boto3.client(
-        "s3",
-        aws_access_key_id=access,
-        aws_secret_access_key=secret,
-        aws_session_token=session_token,
-        region_name=region,
-    )
-
-    # Base metadata, common for both zip and individual files
-    base_metadata: Dict[str, str] = {
-        "upload": str(True).lower(),
-        "user-agent": "aws-fsx-lustre",
-        "file-owner": "537",
-        "file-group": "500",
-        "file-permissions": "100777",
-    }
-
-    if zip_compress and not compress_zip_name:
-        compress_zip_name = str(repository_id)
-
-    if zip_compress:
-        # Create a temporary zip file
-        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_zip_file:
-            zip_path = tmp_zip_file.name
-            tmp_zip_file.close()  # Close the file handle so zipfile can open it
-
-        try:
-            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
-                for file_path in files:
-                    # Add file to zip, using only the basename inside the zip
-                    zipf.write(file_path, arcname=os.path.basename(file_path))
-
-            # Construct object name for the zip file
-            object_name = f"{repository_id}/uploaded/{compress_zip_name}.zip"
-
-            # For zip files, we use the base metadata without a specific checksum
-            # (as checksums are per-file in the original design)
-            extra_args = {
-                "Metadata": base_metadata.copy()
-            }  # Use a copy to avoid modifying base_metadata
-
-            if not upload_file_to_s3(
-                s3_client, bucket_name, zip_path, object_name, extra_args=extra_args
-            ):
-                raise ValueError(
-                    f"Failed to upload zip file {zip_path} to S3 bucket {bucket_name}."
-                )
-
-        finally:
-            # Clean up the temporary zip file
-            if os.path.exists(zip_path):
-                os.unlink(zip_path)
-
-    else:
-        # Original logic: upload files individually
-        for file_path in files:
-            file_basename = os.path.basename(file_path)
-            object_name = f"{repository_id}/uploaded/{file_basename}"
-
-            current_file_metadata = base_metadata.copy()
-            if checksums:
-                current_file_metadata["checksum"] = checksums.get(file_basename, "")
-
-            extra_args = {"Metadata": current_file_metadata}
-
-            if not upload_file_to_s3(
-                s3_client, bucket_name, file_path, object_name, extra_args=extra_args
-            ):
-                raise ValueError(
-                    f"Failed to upload file {file_path} to S3 bucket {bucket_name}."
-                )
-
-    # Always upload .metadata files if the source 'files' list is provided
-    if files:
-        # Assuming all files in the 'files' list share the same parent directory,
-        # which is the case data directory.
-        data_directory = os.path.dirname(files[0])
-        metadata_dir_local_path = os.path.join(data_directory, ".metadata")
-
-        if os.path.isdir(metadata_dir_local_path):
-            # Iterate through the original list of files to find corresponding metadata files
-            for original_file_path in files:
-                original_file_basename = os.path.basename(original_file_path)
-                local_metadata_file_path = os.path.join(
-                    metadata_dir_local_path, original_file_basename
-                )
-
-                if os.path.isfile(local_metadata_file_path):
-                    # S3 object name for the metadata file (e.g., repository_id/.metadata/original_file_basename)
-                    s3_metadata_object_name = (
-                        f"{repository_id}/.metadata/{original_file_basename}"
-                    )
-                    extra_args = {"Metadata": base_metadata.copy()}
-                    if not upload_file_to_s3(
-                        s3_client,
-                        bucket_name,
-                        local_metadata_file_path,
-                        s3_metadata_object_name,
-                        extra_args=extra_args,
-                    ):
-                        raise ValueError(
-                            f"Failed to upload metadata file {local_metadata_file_path} to S3 bucket {bucket_name}."
-                        )
-
-
-def _download_s3_object(
-    s3_client, bucket_name: str, s3_object_key: str, local_file_path: str
-) -> bool:
-    """
-    Downloads a single object from S3 to a local file path.
-
-    :param s3_client: Initialized S3 client.
-    :param bucket_name: Name of the S3 bucket.
-    :param s3_object_key: The key of the object in S3.
-    :param local_file_path: The local path where the file should be saved.
-    :return: True if download was successful, False otherwise.
-    """
-
-    try:
-        s3_client.download_file(bucket_name, s3_object_key, local_file_path)
-        return True
-    except ClientError as e:
-        print(f"ERROR: Failed to download {s3_object_key} from S3: {e}")
-        return False
-
-
-def download_case_from_s3(
-    repository_id: str,
-    cluster_name: str,  # Kept for consistency with caller, though not used directly in S3 ops
-    access: str,
-    secret: str,
-    session_token: str,
-    bucket_name: str,
-    url: str,  # S3 endpoint URL, used by _get_region
-    output_path: str,
-    file_list: List[str],
-) -> List[str]:
-    """
-    Downloads files from an S3 bucket for a given case repository.
-
-    It iterates through the provided `file_list`, downloads each specified file
-    from the S3 path `{repository_id}/{file_in_list}`, preserving its relative path
-    under `output_path`. It then checks if each downloaded file is gzipped,
-    decompresses it if necessary, and returns a list of basenames of the
-    final downloaded (and potentially decompressed) files.
-
-    :param repository_id: The ID of the repository in S3.
-    :param cluster_name: Name of the cluster (for context, not used in S3 calls).
-    :param access: AWS access key ID.
-    :param secret: AWS secret access key.
-    :param session_token: AWS session token.
-    :param bucket_name: Name of the S3 bucket.
-    :param url: S3 service URL (used to determine region via _get_region).
-    :param output_path: Local directory where files will be downloaded.
-    :param file_list: A list of file names (basenames) to be downloaded.
-    :return: A list of basenames of the downloaded (and decompressed) files.
-    :raises ValueError: If S3 connection parameters are missing or filter is invalid.
-    :raises RuntimeError: If S3 operations fail.
-    """
-    region = _get_region(url)
-    if not all([region, access, secret, session_token, bucket_name]):
-        # TODO: Replace print with proper logging
-        print(
-            "ERROR: Missing S3 connection parameters (region, access, secret, token, or bucket name)."
-        )
-        raise ValueError("Missing S3 connection parameters.")
-
-    s3_client = boto3.client(
-        "s3",
-        aws_access_key_id=access,
-        aws_secret_access_key=secret,
-        aws_session_token=session_token,
-        region_name=region,
-    )
-
-    downloaded_files: List[str] = []
-
-    try:
-        for file_name in file_list:
-            # Construct the full S3 object key
-            s3_object_key = f"{repository_id}/{file_name}"
-
-            local_file_path = os.path.join(output_path, file_name)
-            if _download_s3_object(
-                s3_client, bucket_name, s3_object_key, local_file_path
-            ):
-                downloaded_files.append(os.path.basename(local_file_path))
-
-    except ClientError as e:
-        print(f"ERROR: S3 ClientError during download: {e}")
-        raise RuntimeError(f"Failed to download files from S3: {e}")
-    except Exception as e:
-        print(f"ERROR: An unexpected error occurred during download: {e}")
-        raise RuntimeError(f"An unexpected error occurred during S3 download: {e}")
-
-    return downloaded_files