flyte 0.2.0b23__py3-none-any.whl → 0.2.0b25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the package's registry advisory page for more details.

flyte/__init__.py CHANGED
@@ -1,19 +1,5 @@
1
1
  """
2
- Flyte SDK for authoring Compound AI applications, services and workflows.
3
-
4
- ## Environments
5
-
6
- TaskEnvironment class to define a new environment for a set of tasks.
7
-
8
- Example usage:
9
-
10
- ```python
11
- env = flyte.TaskEnvironment(name="my_env", image="my_image", resources=Resources(cpu="1", memory="1Gi"))
12
-
13
- @env.task
14
- async def my_task():
15
- pass
16
- ```
2
+ Flyte SDK for authoring compound AI applications, services and workflows.
17
3
  """
18
4
 
19
5
  __all__ = [
flyte/_bin/runtime.py CHANGED
@@ -32,7 +32,12 @@ ENABLE_REF_TASKS = "_REF_TASKS" # This is a temporary flag to enable reference
32
32
  _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
33
33
 
34
34
 
35
- @click.command("a0")
35
+ @click.group()
36
+ def _pass_through():
37
+ pass
38
+
39
+
40
+ @_pass_through.command("a0")
36
41
  @click.option("--inputs", "-i", required=True)
37
42
  @click.option("--outputs-path", "-o", required=True)
38
43
  @click.option("--version", "-v", required=True)
@@ -158,3 +163,7 @@ def main(
158
163
  await controller.stop()
159
164
 
160
165
  asyncio.run(_run_and_stop())
166
+
167
+
168
+ if __name__ == "__main__":
169
+ _pass_through()
flyte/_environment.py CHANGED
@@ -49,6 +49,13 @@ class Environment:
49
49
  """
50
50
  Add a dependency to the environment.
51
51
  """
52
+ for e in env:
53
+ if not isinstance(e, Environment):
54
+ raise TypeError(f"Expected Environment, got {type(e)}")
55
+ if e.name == self.name:
56
+ raise ValueError("Cannot add self as a dependency")
57
+ if e in self.depends_on:
58
+ continue
52
59
  self.depends_on.extend(env)
53
60
 
54
61
  def clone_with(
flyte/_image.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import base64
4
4
  import hashlib
5
5
  import sys
6
+ import typing
6
7
  from abc import abstractmethod
7
8
  from dataclasses import asdict, dataclass, field
8
9
  from functools import cached_property
@@ -353,16 +354,25 @@ class Image:
353
354
  layer.validate()
354
355
 
355
356
  @classmethod
356
- def _get_default_image_for(cls, python_version: Tuple[int, int], flyte_version: Optional[str] = None) -> Image:
357
+ def _get_default_image_for(
358
+ cls, python_version: Tuple[int, int], flyte_version: Optional[str] = None, install_flyte: bool = True
359
+ ) -> Image:
357
360
  # Would love a way to move this outside of this class (but still needs to be accessible via Image.auto())
358
361
  # this default image definition may need to be updated once there is a released pypi version
359
362
  from flyte._version import __version__
360
363
 
361
- dev_mode = (cls._is_editable_install() or (__version__ and "dev" in __version__)) and not flyte_version
362
- if flyte_version is None:
363
- flyte_version = __version__.replace("+", "-")
364
- preset_tag = flyte_version if flyte_version.startswith("v") else f"v{flyte_version}"
365
- preset_tag = f"py{python_version[0]}.{python_version[1]}-{preset_tag}"
364
+ dev_mode = (
365
+ (cls._is_editable_install() or (__version__ and "dev" in __version__))
366
+ and not flyte_version
367
+ and install_flyte
368
+ )
369
+ if install_flyte is False:
370
+ preset_tag = f"py{python_version[0]}.{python_version[1]}"
371
+ else:
372
+ if flyte_version is None:
373
+ flyte_version = __version__.replace("+", "-")
374
+ suffix = flyte_version if flyte_version.startswith("v") else f"v{flyte_version}"
375
+ preset_tag = f"py{python_version[0]}.{python_version[1]}-{suffix}"
366
376
  image = Image._new(
367
377
  base_image=f"python:{python_version[0]}.{python_version[1]}-slim-bookworm",
368
378
  registry=_BASE_REGISTRY,
@@ -389,14 +399,15 @@ class Image:
389
399
  )
390
400
  image = image.with_apt_packages("build-essential", "ca-certificates")
391
401
 
392
- # Add in flyte library
393
- if dev_mode:
394
- image = image._with_local_v2()
395
- else:
396
- if Version(flyte_version).is_devrelease or Version(flyte_version).is_prerelease:
397
- image = image.with_pip_packages(f"flyte=={flyte_version}", pre=True)
402
+ if install_flyte:
403
+ if dev_mode:
404
+ image = image.with_local_v2()
398
405
  else:
399
- image = image.with_pip_packages(f"flyte=={flyte_version}")
406
+ flyte_version = typing.cast(str, flyte_version)
407
+ if Version(flyte_version).is_devrelease or Version(flyte_version).is_prerelease:
408
+ image = image.with_pip_packages(f"flyte=={flyte_version}", pre=True)
409
+ else:
410
+ image = image.with_pip_packages(f"flyte=={flyte_version}")
400
411
  object.__setattr__(image, "_tag", preset_tag)
401
412
  # Set this to auto for all auto images because the meaning of "auto" can change (based on logic inside
402
413
  # _get_default_image_for, acts differently in a running task container) so let's make sure it stays auto.
@@ -416,6 +427,7 @@ class Image:
416
427
  cls,
417
428
  python_version: Optional[Tuple[int, int]] = None,
418
429
  flyte_version: Optional[str] = None,
430
+ install_flyte: bool = True,
419
431
  registry: Optional[str] = None,
420
432
  name: Optional[str] = None,
421
433
  ) -> Image:
@@ -425,6 +437,7 @@ class Image:
425
437
 
426
438
  :param python_version: If not specified, will use the current Python version
427
439
  :param flyte_version: Union version to use
440
+ :param install_flyte: If True, will install the flyte library in the image
428
441
  :param registry: Registry to use for the image
429
442
  :param name: Name of the image if you want to override the default name
430
443
 
@@ -433,7 +446,9 @@ class Image:
433
446
  if python_version is None:
434
447
  python_version = _detect_python_version()
435
448
 
436
- base_image = cls._get_default_image_for(python_version=python_version, flyte_version=flyte_version)
449
+ base_image = cls._get_default_image_for(
450
+ python_version=python_version, flyte_version=flyte_version, install_flyte=install_flyte
451
+ )
437
452
 
438
453
  if registry and name:
439
454
  return base_image.clone(registry=registry, name=name)
@@ -537,6 +552,7 @@ class Image:
537
552
  self,
538
553
  registry: Optional[str] = None,
539
554
  name: Optional[str] = None,
555
+ python_version: Optional[Tuple[int, int]] = None,
540
556
  addl_layer: Optional[Layer] = None,
541
557
  ) -> Image:
542
558
  """
@@ -544,17 +560,16 @@ class Image:
544
560
 
545
561
  :param registry: Registry to use for the image
546
562
  :param name: Name of the image
563
+ :param python_version: Python version for the image, if not specified, will use the current Python version
547
564
  :param addl_layer: Additional layer to add to the image. This will be added to the end of the layers.
548
- :param addl_layer: Additional layer to add to the image. This will be added on top of the existing layers.
549
565
 
550
566
  :return:
551
567
  """
552
568
  registry = registry if registry else self.registry
553
569
  name = name if name else self.name
554
- if addl_layer and (not name or not registry):
570
+ if addl_layer and (not name):
555
571
  raise ValueError(
556
- f"Cannot add additional layer {addl_layer} to an image"
557
- f" without a registry and name. Please first clone()."
572
+ f"Cannot add additional layer {addl_layer} to an image without name. Please first clone()."
558
573
  )
559
574
  new_layers = (*self._layers, addl_layer) if addl_layer else self._layers
560
575
  img = Image._new(
@@ -563,7 +578,7 @@ class Image:
563
578
  registry=registry,
564
579
  name=name,
565
580
  platform=self.platform,
566
- python_version=self.python_version,
581
+ python_version=python_version or self.python_version,
567
582
  _layers=new_layers,
568
583
  )
569
584
 
@@ -713,7 +728,7 @@ class Image:
713
728
  :param dst: destination folder in the image
714
729
  :return: Image
715
730
  """
716
- new_image = self.clone(addl_layer=CopyConfig(path_type=1, src=src, dst=dst))
731
+ new_image = self.clone(addl_layer=CopyConfig(path_type=1, src=src, dst=dst, _compute_identifier=lambda x: dst))
717
732
  return new_image
718
733
 
719
734
  def with_source_file(self, src: Path, dst: str = ".") -> Image:
@@ -725,7 +740,7 @@ class Image:
725
740
  :param dst: destination folder in the image
726
741
  :return: Image
727
742
  """
728
- new_image = self.clone(addl_layer=CopyConfig(path_type=0, src=src, dst=dst))
743
+ new_image = self.clone(addl_layer=CopyConfig(path_type=0, src=src, dst=dst, _compute_identifier=lambda x: dst))
729
744
  return new_image
730
745
 
731
746
  def with_uv_project(
@@ -777,7 +792,7 @@ class Image:
777
792
  new_image = self.clone(addl_layer=Commands(commands=new_commands))
778
793
  return new_image
779
794
 
780
- def _with_local_v2(self) -> Image:
795
+ def with_local_v2(self) -> Image:
781
796
  """
782
797
  Use this method to create a new image with the local v2 builder
783
798
  This will override any existing builder
@@ -1,12 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
+ from typing import Any, Literal
4
5
 
5
6
  from flyteidl.core import execution_pb2
6
7
 
7
8
  from flyte._protos.workflow import run_definition_pb2, state_service_pb2, task_definition_pb2
8
9
  from flyte.models import GroupData
9
10
 
11
+ ActionType = Literal["task", "trace"]
12
+
10
13
 
11
14
  @dataclass
12
15
  class Action:
@@ -17,6 +20,7 @@ class Action:
17
20
 
18
21
  action_id: run_definition_pb2.ActionIdentifier
19
22
  parent_action_name: str
23
+ type: ActionType = "task" # type of action, task or trace
20
24
  friendly_name: str | None = None
21
25
  group: GroupData | None = None
22
26
  task: task_definition_pb2.TaskSpec | None = None
@@ -144,3 +148,28 @@ class Action:
144
148
  err=obj.error if obj.HasField("error") else None,
145
149
  realized_outputs_uri=obj.output_uri,
146
150
  )
151
+
152
+ @classmethod
153
+ def from_trace(
154
+ cls,
155
+ parent_action_name: str,
156
+ action_id: run_definition_pb2.ActionIdentifier,
157
+ friendly_name: str,
158
+ group_data: GroupData | None,
159
+ trace_spec: Any, # TODO
160
+ inputs_uri: str,
161
+ outputs_uri: str,
162
+ ) -> Action:
163
+ """
164
+ This creates a new action for tracing purposes. It is used to track the execution of a trace.
165
+ """
166
+ return cls(
167
+ action_id=action_id,
168
+ parent_action_name=parent_action_name,
169
+ type="trace",
170
+ friendly_name=friendly_name,
171
+ group=group_data,
172
+ inputs_uri=inputs_uri,
173
+ realized_outputs_uri=outputs_uri,
174
+ phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
175
+ )
@@ -174,13 +174,12 @@ class RemoteController(Controller):
174
174
  )
175
175
 
176
176
  task_spec = translate_task_to_wire(_task, new_serialization_context)
177
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
178
-
179
- inputs_hash = convert.generate_inputs_hash(serialized_inputs)
177
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
180
178
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
181
179
  tctx, task_spec, inputs_hash, _task_call_seq
182
180
  )
183
181
 
182
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
184
183
  inputs_uri = io.inputs_path(sub_action_output_path)
185
184
  await upload_inputs_with_retry(serialized_inputs, inputs_uri)
186
185
 
@@ -391,10 +390,12 @@ class RemoteController(Controller):
391
390
  if tctx is None:
392
391
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
393
392
 
393
+ current_action_id = tctx.action
394
394
  current_output_path = tctx.output_path
395
395
  sub_run_output_path = storage.join(current_output_path, info.action.name)
396
396
 
397
397
  if info.interface.has_outputs():
398
+ outputs_file_path: str = ""
398
399
  if info.output:
399
400
  outputs = await convert.convert_from_native_to_outputs(info.output, info.interface)
400
401
  outputs_file_path = io.outputs_path(sub_run_output_path)
@@ -405,6 +406,33 @@ class RemoteController(Controller):
405
406
  await io.upload_error(err.err, error_path)
406
407
  else:
407
408
  raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error")
409
+ trace_action = Action.from_trace(
410
+ parent_action_name=current_action_id.name,
411
+ action_id=run_definition_pb2.ActionIdentifier(
412
+ name=info.action.name,
413
+ run=run_definition_pb2.RunIdentifier(
414
+ name=current_action_id.run_name,
415
+ project=current_action_id.project,
416
+ domain=current_action_id.domain,
417
+ org=current_action_id.org,
418
+ ),
419
+ ),
420
+ inputs_uri=info.inputs_path,
421
+ outputs_uri=outputs_file_path,
422
+ trace_spec="",
423
+ friendly_name=info.name,
424
+ group_data=tctx.group_data,
425
+ )
426
+ try:
427
+ logger.info(
428
+ f"Submitting Trace action Run:[{trace_action.run_name}, Parent:[{trace_action.parent_action_name}],"
429
+ f" Trace fn:[{info.name}], action:[{info.action.name}]"
430
+ )
431
+ await self.submit_action(trace_action)
432
+ logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!")
433
+ except asyncio.CancelledError:
434
+ # If the action is cancelled, we need to cancel the action on the server as well
435
+ raise
408
436
 
409
437
  async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
410
438
  ctx = internal_ctx()
@@ -420,12 +448,12 @@ class RemoteController(Controller):
420
448
  _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs
421
449
  )
422
450
  inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs)
423
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
424
- inputs_hash = convert.generate_inputs_hash(serialized_inputs)
451
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
425
452
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
426
453
  tctx, task_name, inputs_hash, invoke_seq_num
427
454
  )
428
455
 
456
+ serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
429
457
  inputs_uri = io.inputs_path(sub_action_output_path)
430
458
  await upload_inputs_with_retry(serialized_inputs, inputs_uri)
431
459
  # cache key - task name, task signature, inputs, cache version
@@ -309,29 +309,43 @@ class Controller:
309
309
  """
310
310
  Attempt to launch an action.
311
311
  """
312
- if not action.is_started() and action.task is not None:
313
- logger.debug(f"Attempting to launch action: {action.name}")
314
- try:
312
+ if not action.is_started():
313
+ task: queue_service_pb2.TaskAction | None = None
314
+ trace: queue_service_pb2.TraceAction | None = None
315
+ if action.type == "task":
316
+ if action.task is None:
317
+ raise flyte.errors.RuntimeSystemError(
318
+ "NoTaskSpec", "Task Spec not found, cannot launch Task Action."
319
+ )
315
320
  cache_key = None
316
321
  logger.info(f"Action {action.name} has cache version {action.cache_key}")
317
322
  if action.cache_key:
318
323
  cache_key = StringValue(value=action.cache_key)
319
324
 
325
+ task = queue_service_pb2.TaskAction(
326
+ id=task_definition_pb2.TaskIdentifier(
327
+ version=action.task.task_template.id.version,
328
+ org=action.task.task_template.id.org,
329
+ project=action.task.task_template.id.project,
330
+ domain=action.task.task_template.id.domain,
331
+ name=action.task.task_template.id.name,
332
+ ),
333
+ spec=action.task,
334
+ cache_key=cache_key,
335
+ )
336
+ else:
337
+ trace = queue_service_pb2.TraceAction(
338
+ name=action.friendly_name,
339
+ )
340
+
341
+ logger.debug(f"Attempting to launch action: {action.name}")
342
+ try:
320
343
  await self._queue_service.EnqueueAction(
321
344
  queue_service_pb2.EnqueueActionRequest(
322
345
  action_id=action.action_id,
323
346
  parent_action_name=action.parent_action_name,
324
- task=queue_service_pb2.TaskAction(
325
- id=task_definition_pb2.TaskIdentifier(
326
- version=action.task.task_template.id.version,
327
- org=action.task.task_template.id.org,
328
- project=action.task.task_template.id.project,
329
- domain=action.task.task_template.id.domain,
330
- name=action.task.task_template.id.name,
331
- ),
332
- spec=action.task,
333
- cache_key=cache_key,
334
- ),
347
+ task=task,
348
+ trace=trace,
335
349
  input_uri=action.inputs_uri,
336
350
  run_output_base=action.run_output_base,
337
351
  group=action.group.name if action.group else None,
@@ -1,15 +1,3 @@
1
- import asyncio
2
- from typing import List
3
-
4
- from flyte._image import Image
5
- from flyte._internal.imagebuild.docker_builder import DockerImageBuilder
6
1
  from flyte._internal.imagebuild.image_builder import ImageBuildEngine
7
- from flyte._internal.imagebuild.remote_builder import RemoteImageBuilder
8
-
9
- __all__ = ["DockerImageBuilder", "ImageBuildEngine", "RemoteImageBuilder"]
10
-
11
2
 
12
- async def build(images: List[Image]) -> List[str]:
13
- builder = DockerImageBuilder()
14
- ts = [asyncio.create_task(builder.build_image(image)) for image in images]
15
- return list(await asyncio.gather(*ts))
3
+ __all__ = ["ImageBuildEngine"]
@@ -456,6 +456,10 @@ class DockerImageBuilder(ImageBuilder):
456
456
  else:
457
457
  click.secho(f"Run command: {concat_command} ", fg="blue")
458
458
 
459
- await asyncio.to_thread(subprocess.run, command, check=True)
459
+ try:
460
+ await asyncio.to_thread(subprocess.run, command, check=True)
461
+ except subprocess.CalledProcessError as e:
462
+ logger.error(f"Failed to build image: {e}")
463
+ raise RuntimeError(f"Failed to build image: {e}")
460
464
 
461
465
  return image.uri
@@ -138,9 +138,17 @@ async def _validate_configuration(image: Image) -> Tuple[str, Optional[str]]:
138
138
 
139
139
  if any(context_path.iterdir()):
140
140
  # If there are files in the context directory, upload it
141
- _, context_url = await remote.upload_file.aio(
142
- Path(shutil.make_archive(str(tmp_path / "context"), "xztar", context_path))
143
- )
141
+ archive = Path(shutil.make_archive(str(tmp_path / "context"), "xztar", context_path))
142
+ st = archive.stat()
143
+ if st.st_size > 5 * 1024 * 1024:
144
+ logger.warning(
145
+ click.style(
146
+ f"Context size is {st.st_size / (1024 * 1024):.2f} MB, which is larger than 5 MB. "
147
+ "Upload and build speed will be impacted.",
148
+ fg="yellow",
149
+ )
150
+ )
151
+ _, context_url = await remote.upload_file.aio(archive)
144
152
  else:
145
153
  context_url = ""
146
154
 
@@ -155,9 +155,18 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, tas
155
155
  if not isinstance(o, tuple):
156
156
  o = (o,)
157
157
 
158
- assert len(interface.outputs) == len(interface.outputs), (
159
- f"Received {len(o)} outputs but interface has {len(interface.outputs)}"
160
- )
158
+ if len(interface.outputs) == 0:
159
+ if len(o) != 0:
160
+ if len(o) == 1 and o[0] is not None:
161
+ raise flyte.errors.RuntimeDataValidationError(
162
+ "o0",
163
+ f"Expected no outputs but got {o},did you miss a return type annotation?",
164
+ task_name,
165
+ )
166
+ else:
167
+ assert len(o) == len(interface.outputs), (
168
+ f"Received {len(o)} outputs but return annotation has {len(interface.outputs)} outputs specified. "
169
+ )
161
170
  named = []
162
171
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
163
172
  try:
@@ -279,6 +288,29 @@ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
279
288
  return hash_data(serialized_inputs)
280
289
 
281
290
 
291
+ def generate_inputs_hash_from_proto(inputs: run_definition_pb2.Inputs) -> str:
292
+ """
293
+ Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
294
+ :param inputs: The inputs to hash.
295
+ :return: A hexadecimal string representation of the hash.
296
+ """
297
+ if not inputs:
298
+ return ""
299
+ return generate_inputs_hash(inputs.SerializeToString(deterministic=True))
300
+
301
+
302
+ def generate_interface_hash(task_interface: interface_pb2.TypedInterface) -> str:
303
+ """
304
+ Generate a hash for the task interface. This is used to uniquely identify the task interface.
305
+ :param task_interface: The interface of the task.
306
+ :return: A hexadecimal string representation of the hash.
307
+ """
308
+ if not task_interface:
309
+ return ""
310
+ serialized_interface = task_interface.SerializeToString(deterministic=True)
311
+ return hash_data(serialized_interface)
312
+
313
+
282
314
  def generate_cache_key_hash(
283
315
  task_name: str,
284
316
  inputs_hash: str,
@@ -302,10 +334,13 @@ def generate_cache_key_hash(
302
334
  if ignored_input_vars:
303
335
  filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
304
336
  final = run_definition_pb2.Inputs(literals=filtered)
305
- final_inputs = final.SerializeToString(deterministic=True)
337
+ final_inputs = generate_inputs_hash_from_proto(final)
306
338
  else:
307
339
  final_inputs = inputs_hash
308
- data = f"{final_inputs}{task_name}{task_interface.SerializeToString(deterministic=True)}{cache_version}"
340
+
341
+ interface_hash = generate_interface_hash(task_interface)
342
+
343
+ data = f"{final_inputs}{task_name}{interface_hash}{cache_version}"
309
344
  return hash_data(data)
310
345
 
311
346
 
@@ -1,4 +1,5 @@
1
- from typing import List, Optional, Tuple
1
+ import importlib
2
+ from typing import List, Optional, Tuple, Type
2
3
 
3
4
  import flyte.errors
4
5
  from flyte._code_bundle import download_bundle
@@ -10,7 +11,6 @@ from flyte._task import TaskTemplate
10
11
  from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
11
12
 
12
13
  from .convert import Error, Inputs, Outputs
13
- from .task_serde import load_task
14
14
  from .taskrunner import (
15
15
  convert_and_run,
16
16
  extract_download_run_upload,
@@ -51,25 +51,68 @@ async def direct_dispatch(
51
51
  )
52
52
 
53
53
 
54
+ def load_class(qualified_name) -> Type:
55
+ """
56
+ Load a class from a qualified name. The qualified name should be in the format 'module.ClassName'.
57
+ :param qualified_name: The qualified name of the class to load.
58
+ :return: The class object.
59
+ """
60
+ module_name, class_name = qualified_name.rsplit(".", 1) # Split module and class
61
+ module = importlib.import_module(module_name) # Import the module
62
+ return getattr(module, class_name) # Retrieve the class
63
+
64
+
65
+ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
66
+ """
67
+ Load a task from a resolver. This is a placeholder function.
68
+
69
+ :param resolver: The resolver to use to load the task.
70
+ :param resolver_args: Arguments to pass to the resolver.
71
+ :return: The loaded task.
72
+ """
73
+ resolver_class = load_class(resolver)
74
+ resolver_instance = resolver_class()
75
+ return resolver_instance.load_task(resolver_args)
76
+
77
+
78
+ def load_pkl_task(code_bundle: CodeBundle) -> TaskTemplate:
79
+ """
80
+ Loads a task from a pickled code bundle.
81
+ :param code_bundle: The code bundle to load the task from.
82
+ :return: The loaded task template.
83
+ """
84
+ logger.debug(f"Loading task from pkl: {code_bundle.downloaded_path}")
85
+ try:
86
+ import gzip
87
+
88
+ import cloudpickle
89
+
90
+ with gzip.open(str(code_bundle.downloaded_path), "rb") as f:
91
+ return cloudpickle.load(f)
92
+ except Exception as e:
93
+ logger.exception(f"Failed to load pickled task from {code_bundle.downloaded_path}. Reason: {e!s}")
94
+ raise
95
+
96
+
97
+ async def download_code_bundle(code_bundle: CodeBundle) -> CodeBundle:
98
+ """
99
+ Downloads the code bundle if it is not already downloaded.
100
+ :param code_bundle: The code bundle to download.
101
+ :return: The code bundle with the downloaded path.
102
+ """
103
+ logger.debug(f"Downloading {code_bundle}")
104
+ downloaded_path = await download_bundle(code_bundle)
105
+ return code_bundle.with_downloaded_path(downloaded_path)
106
+
107
+
54
108
  async def _download_and_load_task(
55
109
  code_bundle: CodeBundle | None, resolver: str | None = None, resolver_args: List[str] | None = None
56
110
  ) -> TaskTemplate:
57
111
  if code_bundle and (code_bundle.tgz or code_bundle.pkl):
58
112
  logger.debug(f"Downloading {code_bundle}")
59
- downloaded_path = await download_bundle(code_bundle)
60
- code_bundle = code_bundle.with_downloaded_path(downloaded_path)
113
+ code_bundle = await download_code_bundle(code_bundle)
61
114
  if code_bundle.pkl:
62
- try:
63
- logger.debug(f"Loading task from pkl: {code_bundle.downloaded_path}")
64
- import gzip
65
-
66
- import cloudpickle
67
-
68
- with gzip.open(str(code_bundle.downloaded_path), "rb") as f:
69
- return cloudpickle.load(f)
70
- except Exception as e:
71
- logger.exception(f"Failed to load pickled task from {code_bundle.downloaded_path}. Reason: {e!s}")
72
- raise
115
+ return load_pkl_task(code_bundle)
73
116
 
74
117
  if not resolver or not resolver_args:
75
118
  raise flyte.errors.RuntimeSystemError(