flyte 2.0.0b17__py3-none-any.whl → 2.0.0b18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of flyte has been flagged for review.

flyte/_logging.py CHANGED
@@ -75,39 +75,52 @@ def get_rich_handler(log_level: int) -> Optional[logging.Handler]:
     return handler
 
 
-def get_default_handler(log_level: int) -> logging.Handler:
-    handler = logging.StreamHandler()
-    handler.setLevel(log_level)
-    formatter = logging.Formatter(fmt="[%(name)s] %(message)s")
-    if log_format_from_env() == "json":
-        pass
-        # formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
-    handler.setFormatter(formatter)
-    return handler
-
-
 def initialize_logger(log_level: int = get_env_log_level(), enable_rich: bool = False):
     """
     Initializes the global loggers to the default configuration.
+    When enable_rich=True, upgrades to Rich handler for local CLI usage.
     """
     global logger  # noqa: PLW0603
-    logger = _create_logger("flyte", log_level, enable_rich)
 
+    # Clear existing handlers to reconfigure
+    root = logging.getLogger()
+    root.handlers.clear()
 
-def _create_logger(name: str, log_level: int = DEFAULT_LOG_LEVEL, enable_rich: bool = False) -> logging.Logger:
-    """
-    Creates a logger with the given name and log level.
-    """
-    logger = logging.getLogger(name)
-    logger.setLevel(log_level)
-    handler = None
-    logger.handlers = []
+    flyte_logger = logging.getLogger("flyte")
+    flyte_logger.handlers.clear()
+
+    # Set up root logger handler
+    root_handler = None
+    if enable_rich:
+        root_handler = get_rich_handler(log_level)
+
+    if root_handler is None:
+        root_handler = logging.StreamHandler()
+
+    # Add context filter to root handler for all logging
+    root_handler.addFilter(ContextFilter())
+    root.addHandler(root_handler)
+
+    # Set up Flyte logger handler
+    flyte_handler = None
     if enable_rich:
-        handler = get_rich_handler(log_level)
-    if handler is None:
-        handler = get_default_handler(log_level)
-    logger.addHandler(handler)
-    return logger
+        flyte_handler = get_rich_handler(log_level)
+
+    if flyte_handler is None:
+        flyte_handler = logging.StreamHandler()
+        flyte_handler.setLevel(log_level)
+        formatter = logging.Formatter(fmt="%(message)s")
+        flyte_handler.setFormatter(formatter)
+
+    # Add both filters to Flyte handler
+    flyte_handler.addFilter(FlyteInternalFilter())
+    flyte_handler.addFilter(ContextFilter())
+
+    flyte_logger.addHandler(flyte_handler)
+    flyte_logger.setLevel(log_level)
+    flyte_logger.propagate = False  # Prevent double logging
+
+    logger = flyte_logger
 
 
 def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
@@ -135,4 +148,75 @@ def log(fn=None, *, level=logging.DEBUG, entry=True, exit=True):
     return decorator(fn)
 
 
-logger = _create_logger("flyte", get_env_log_level())
+class ContextFilter(logging.Filter):
+    """
+    A logging filter that adds the current action's run name and name to all log records.
+    Applied globally to capture context for both user and Flyte internal logging.
+    """
+
+    def filter(self, record):
+        from flyte._context import ctx
+
+        c = ctx()
+        if c:
+            action = c.action
+            record.msg = f"[{action.run_name}][{action.name}] {record.msg}"
+        return True
+
+
+class FlyteInternalFilter(logging.Filter):
+    """
+    A logging filter that adds [flyte] prefix to internal Flyte logging only.
+    """
+
+    def filter(self, record):
+        if record.name.startswith("flyte"):
+            record.msg = f"[flyte] {record.msg}"
+        return True
+
+
+def _setup_root_logger():
+    """
+    Configure the root logger to capture all logging with context information.
+    This ensures both user code and Flyte internal logging get the context.
+    """
+    root = logging.getLogger()
+    root.handlers.clear()  # Remove any existing handlers to prevent double logging
+
+    # Create a basic handler for the root logger
+    handler = logging.StreamHandler()
+    # Add context filter to ALL logging
+    handler.addFilter(ContextFilter())
+
+    # Simple formatter since filters handle prefixes
+    root.addHandler(handler)
+
+
+def _create_flyte_logger() -> logging.Logger:
+    """
+    Create the internal Flyte logger with [flyte] prefix.
+    """
+    flyte_logger = logging.getLogger("flyte")
+    flyte_logger.setLevel(get_env_log_level())
+
+    # Add a handler specifically for flyte logging with the prefix filter
+    handler = logging.StreamHandler()
+    handler.setLevel(get_env_log_level())
+    handler.addFilter(FlyteInternalFilter())
+    handler.addFilter(ContextFilter())
+
+    formatter = logging.Formatter(fmt="%(message)s")
+    handler.setFormatter(formatter)
+
+    # Prevent propagation to root to avoid double logging
+    flyte_logger.propagate = False
+    flyte_logger.addHandler(handler)
+
+    return flyte_logger
+
+
+# Initialize root logger for global context
+_setup_root_logger()
+
+# Create the Flyte internal logger
+logger = _create_flyte_logger()
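
The rewritten `initialize_logger` drops the old `_create_logger`/`get_default_handler` pair in favor of two stdlib `logging.Filter`s: `ContextFilter` stamps every record with the current run and action names, while `FlyteInternalFilter` prefixes records from the `flyte` namespace, and the `flyte` logger stops propagating to root so nothing prints twice. A minimal, self-contained sketch of the same filter-plus-dedicated-handler pattern (names and messages here are illustrative, not part of the release):

```python
import logging

class NamespacePrefixFilter(logging.Filter):
    """Illustrative stand-in for FlyteInternalFilter: tag records from one namespace."""

    def filter(self, record: logging.LogRecord) -> bool:
        if record.name.startswith("flyte"):
            record.msg = f"[flyte] {record.msg}"
        return True  # filters here only annotate; they never drop records

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt="%(message)s"))
handler.addFilter(NamespacePrefixFilter())

flyte_logger = logging.getLogger("flyte")
flyte_logger.addHandler(handler)
flyte_logger.setLevel(logging.INFO)
flyte_logger.propagate = False  # keep records off the root handler, avoiding double output

flyte_logger.info("engine started")  # prints: [flyte] engine started
```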
flyte/_pod.py CHANGED
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict, Optional
 
 if TYPE_CHECKING:
     from flyteidl.core.tasks_pb2 import K8sPod
-    from kubernetes.client import ApiClient, V1PodSpec
+    from kubernetes.client import V1PodSpec
 
 
 _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
@@ -21,6 +21,7 @@ class PodTemplate(object):
 
     def to_k8s_pod(self) -> "K8sPod":
         from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
+        from kubernetes.client import ApiClient
 
         return K8sPod(
             metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
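
In b17, `ApiClient` was imported only under `TYPE_CHECKING`, which is invisible at runtime, so any runtime use inside `to_k8s_pod` would have raised `NameError`; b18 imports it at call time instead, which also keeps `kubernetes` an on-demand dependency. The deferred-import pattern, sketched with an illustrative helper (not the release's exact code):

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Seen only by type checkers; creates no runtime dependency on kubernetes.
    from kubernetes.client import V1PodSpec

def serialize_pod_spec(pod_spec: "V1PodSpec") -> Any:
    # Deferred import: kubernetes must be installed only when this actually runs.
    from kubernetes.client import ApiClient

    return ApiClient().sanitize_for_serialization(pod_spec)
```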
flyte/_run.py CHANGED
@@ -161,7 +161,12 @@ class _Runner:
             code_bundle = cached_value.code_bundle
             image_cache = cached_value.image_cache
         else:
-            image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
+            if not self._dry_run:
+                image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
+            else:
+                from ._internal.imagebuild.image_builder import ImageCache
+
+                image_cache = ImageCache(image_lookup={})
 
         if self._interactive_mode:
             code_bundle = await build_pkl_bundle(
@@ -18,7 +18,7 @@ from typing import (
 
 import rich.repr
 
-from ._cache import CacheRequest
+from ._cache import Cache, CacheRequest
 from ._doc import Documentation
 from ._environment import Environment
 from ._image import Image
@@ -74,6 +74,10 @@ class TaskEnvironment(Environment):
         super().__post_init__()
         if self.reusable is not None and self.plugin_config is not None:
             raise ValueError("Cannot set plugin_config when environment is reusable.")
+        if self.reusable and not isinstance(self.reusable, ReusePolicy):
+            raise TypeError(f"Expected reusable to be of type ReusePolicy, got {type(self.reusable)}")
+        if self.cache and not isinstance(self.cache, (str, Cache)):
+            raise TypeError(f"Expected cache to be of type str or Cache, got {type(self.cache)}")
 
     def clone_with(
         self,
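
Two behavioral changes land here: dry runs skip image building and substitute an empty `ImageCache`, and `TaskEnvironment.__post_init__` now rejects mistyped `reusable` and `cache` arguments at construction time with a `TypeError` instead of failing later. A hedged sketch of how the new validation surfaces, assuming `TaskEnvironment` is re-exported as `flyte.TaskEnvironment` and `name` is its only required field:

```python
import flyte

# Accepted: cache may be a CacheRequest string or a Cache instance.
env = flyte.TaskEnvironment(name="demo", cache="auto")

# Rejected immediately rather than at submission time.
try:
    flyte.TaskEnvironment(name="bad", reusable="yes")  # not a ReusePolicy
except TypeError as err:
    print(err)  # Expected reusable to be of type ReusePolicy, got <class 'str'>
```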
flyte/_trace.py CHANGED
@@ -3,6 +3,7 @@ import inspect
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
 
+from flyte._logging import logger
 from flyte.models import NativeInterface
 
 T = TypeVar("T")
@@ -33,10 +34,13 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
         iface = NativeInterface.from_callable(func)
         info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
         if ok:
+            logger.info(f"Found existing trace info for {func}, {info}")
             if info.output:
                 return info.output
             elif info.error:
                 raise info.error
+        else:
+            logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
         start_time = time.time()
         try:
             # Cast to Awaitable to satisfy mypy
@@ -44,6 +48,7 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
             results = await coroutine_result
             info.add_outputs(results, start_time=start_time, end_time=time.time())
             await controller.record_trace(info)
+            logger.debug(f"Finished trace for {func}, {info}")
             return results
         except Exception as e:
             # If there is an error, we need to record it
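
The `@trace` wrapper now reports whether a previously recorded action output was found (and replayed) or the function body will execute. These messages go to the `flyte` logger, so a small illustrative snippet for surfacing them locally (assuming the default handlers described in the `_logging.py` diff above):

```python
import logging

# The new trace messages are emitted at INFO (cache hit) and DEBUG (miss/finish).
logging.getLogger("flyte").setLevel(logging.DEBUG)
```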
flyte/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '2.0.0b17'
-__version_tuple__ = version_tuple = (2, 0, 0, 'b17')
+__version__ = version = '2.0.0b18'
+__version_tuple__ = version_tuple = (2, 0, 0, 'b18')
 
-__commit_id__ = commit_id = 'gfe0ca1266'
+__commit_id__ = commit_id = 'g930faeaea'
flyte/io/_dir.py CHANGED
@@ -48,6 +48,7 @@ class Dir(BaseModel, Generic[T], SerializableType):
     path: str
     name: Optional[str] = None
     format: str = ""
+    hash: Optional[str] = None
 
     class Config:
         arbitrary_types_allowed = True
@@ -248,13 +249,20 @@ class Dir(BaseModel, Generic[T], SerializableType):
             raise NotImplementedError("Sync download is not implemented for remote paths")
 
     @classmethod
-    async def from_local(cls, local_path: Union[str, Path], remote_path: Optional[str] = None) -> Dir[T]:
+    async def from_local(
+        cls,
+        local_path: Union[str, Path],
+        remote_path: Optional[str] = None,
+        dir_cache_key: Optional[str] = None,
+    ) -> Dir[T]:
         """
         Asynchronously create a new Dir by uploading a local directory to the configured remote store.
 
         Args:
             local_path: Path to the local directory
             remote_path: Optional path to store the directory remotely. If None, a path will be generated.
+            dir_cache_key: If you have a precomputed hash value you want to use when computing cache keys for
+                discoverable tasks that this Dir is an input to.
 
         Returns:
             A new Dir instance pointing to the uploaded directory
@@ -262,13 +270,34 @@ class Dir(BaseModel, Generic[T], SerializableType):
         Example:
         ```python
         remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/')
+        # With a known hash value you want to use for cache key calculation
+        remote_dir = await Dir[DataFrame].from_local('/tmp/data_dir/', 's3://bucket/data/', dir_cache_key='abc123')
         ```
         """
         local_path_str = str(local_path)
         dirname = os.path.basename(os.path.normpath(local_path_str))
 
         output_path = await storage.put(from_path=local_path_str, to_path=remote_path, recursive=True)
-        return cls(path=output_path, name=dirname)
+        return cls(path=output_path, name=dirname, hash=dir_cache_key)
+
+    @classmethod
+    def from_existing_remote(cls, remote_path: str, dir_cache_key: Optional[str] = None) -> Dir[T]:
+        """
+        Create a Dir reference from an existing remote directory.
+
+        Args:
+            remote_path: The remote path to the existing directory
+            dir_cache_key: Optional hash value to use for cache key computation. If not specified,
+                the cache key will be computed based on this object's attributes.
+
+        Example:
+        ```python
+        remote_dir = Dir.from_existing_remote("s3://bucket/data/")
+        # With a known hash
+        remote_dir = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="abc123")
+        ```
+        """
+        return cls(path=remote_path, hash=dir_cache_key)
 
     @classmethod
     def from_local_sync(cls, local_path: Union[str, Path], remote_path: Optional[str] = None) -> Dir[T]:
@@ -414,7 +443,8 @@ class DirTransformer(TypeTransformer[Dir]):
                     ),
                     uri=python_val.path,
                 )
-            )
+            ),
+            hash=python_val.hash if python_val.hash else None,
         )
 
     async def to_python_value(
@@ -432,7 +462,8 @@ class DirTransformer(TypeTransformer[Dir]):
 
         uri = lv.scalar.blob.uri
         filename = Path(uri).name
-        f: Dir = Dir(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format)
+        hash_value = lv.hash if lv.hash else None
+        f: Dir = Dir(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format, hash=hash_value)
         return f
 
     def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[Dir]:
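
`Dir` gains an optional `hash` field, a `dir_cache_key` parameter on `from_local`, and a new `from_existing_remote` constructor, letting callers pin the value used for cache-key computation instead of deriving it from the object's attributes. A usage sketch based on the signatures above (bucket paths and keys are illustrative, and `Dir` is assumed to be importable from `flyte.io`):

```python
from flyte.io import Dir

# Reference an already-uploaded directory and supply the cache key directly.
d = Dir.from_existing_remote("s3://bucket/data/", dir_cache_key="v1")

async def publish() -> Dir:
    # Upload a local directory, attaching a precomputed cache key.
    return await Dir.from_local("/tmp/data_dir/", dir_cache_key="v1")
```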
flyte/io/_file.py CHANGED
@@ -5,6 +5,7 @@ from contextlib import asynccontextmanager, contextmanager
 from pathlib import Path
 from typing import (
     IO,
+    Annotated,
     Any,
     AsyncGenerator,
     Dict,
@@ -21,12 +22,14 @@ from flyteidl.core import literals_pb2, types_pb2
 from fsspec.asyn import AsyncFileSystem
 from fsspec.utils import get_protocol
 from mashumaro.types import SerializableType
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, Field, model_validator
+from pydantic.json_schema import SkipJsonSchema
 
 import flyte.storage as storage
 from flyte._context import internal_ctx
 from flyte._initialize import requires_initialization
 from flyte._logging import logger
+from flyte.io._hashing_io import AsyncHashingReader, HashingWriter, HashMethod, PrecomputedValue
 from flyte.types import TypeEngine, TypeTransformer, TypeTransformerFailedError
 
 # Type variable for the file format
@@ -104,6 +107,8 @@ class File(BaseModel, Generic[T], SerializableType):
     path: str
     name: Optional[str] = None
     format: str = ""
+    hash: Optional[str] = None
+    hash_method: Annotated[Optional[HashMethod], Field(default=None, exclude=True), SkipJsonSchema()] = None
 
     class Config:
         arbitrary_types_allowed = True
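
The new `hash_method` field combines `Field(exclude=True)` with `SkipJsonSchema()` so the runtime-only hashing callback never appears in serialized payloads. The pydantic v2 pattern in isolation, with an illustrative model:

```python
from typing import Annotated, Optional

from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema

class Ref(BaseModel):
    path: str
    # Runtime-only: excluded from model_dump(), and SkipJsonSchema keeps it
    # out of the generated JSON schema as well.
    helper: Annotated[Optional[object], Field(default=None, exclude=True), SkipJsonSchema()] = None

print(Ref(path="s3://x", helper=object()).model_dump())  # {'path': 's3://x'}
```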
@@ -139,7 +144,7 @@ class File(BaseModel, Generic[T], SerializableType):
 
     @classmethod
     @requires_initialization
-    def new_remote(cls) -> File[T]:
+    def new_remote(cls, hash_method: Optional[HashMethod | str] = None) -> File[T]:
         """
         Create a new File reference for a remote file that will be written to.
 
@@ -155,11 +160,13 @@ class File(BaseModel, Generic[T], SerializableType):
         ```
         """
         ctx = internal_ctx()
+        known_cache_key = hash_method if isinstance(hash_method, str) else None
+        method = hash_method if isinstance(hash_method, HashMethod) else None
 
-        return cls(path=ctx.raw_data.get_random_remote_path())
+        return cls(path=ctx.raw_data.get_random_remote_path(), hash=known_cache_key, hash_method=method)
 
     @classmethod
-    def from_existing_remote(cls, remote_path: str) -> File[T]:
+    def from_existing_remote(cls, remote_path: str, file_cache_key: Optional[str] = None) -> File[T]:
         """
         Create a File reference from an existing remote file.
 
@@ -172,8 +179,10 @@ class File(BaseModel, Generic[T], SerializableType):
 
         Args:
             remote_path: The remote path to the existing file
+            file_cache_key: Optional hash value to use for discovery purposes. If not specified, the value of this
+                File object will be hashed (basically the path, not the contents).
         """
-        return cls(path=remote_path)
+        return cls(path=remote_path, hash=file_cache_key)
 
     @asynccontextmanager
     async def open(
@@ -184,7 +193,7 @@ class File(BaseModel, Generic[T], SerializableType):
         cache_options: Optional[dict] = None,
         compression: Optional[str] = None,
         **kwargs,
-    ) -> AsyncGenerator[IO[Any]]:
+    ) -> AsyncGenerator[Union[IO[Any], "HashingWriter"], None]:
         """
         Asynchronously open the file and return a file-like object.
 
@@ -245,7 +254,15 @@ class File(BaseModel, Generic[T], SerializableType):
             file_handle.close()
 
         with fs.open(self.path, mode) as file_handle:
-            yield file_handle
+            if self.hash_method and self.hash is None:
+                logger.debug(f"Wrapping file handle with hashing writer using {self.hash_method}")
+                fh = HashingWriter(file_handle, accumulator=self.hash_method)
+                yield fh
+                self.hash = fh.result()
+                fh.close()
+            else:
+                yield file_handle
+                file_handle.close()
 
     def exists_sync(self) -> bool:
         """
@@ -351,13 +368,22 @@ class File(BaseModel, Generic[T], SerializableType):
 
     @classmethod
     @requires_initialization
-    async def from_local(cls, local_path: Union[str, Path], remote_destination: Optional[str] = None) -> File[T]:
+    async def from_local(
+        cls,
+        local_path: Union[str, Path],
+        remote_destination: Optional[str] = None,
+        hash_method: Optional[HashMethod | str] = None,
+    ) -> File[T]:
         """
         Create a new File object from a local file that will be uploaded to the configured remote store.
 
         Args:
             local_path: Path to the local file
             remote_destination: Optional path to store the file remotely. If None, a path will be generated.
+            hash_method: Pass this argument either as a set string or a HashMethod to use for
+                determining a task's cache key if this File object is used as an input to said task. If not specified,
+                the cache key will just be computed based on this object's attributes (i.e. path, name, format, etc.).
+                If there is a set value you want to use, please pass an instance of the PrecomputedValue HashMethod.
 
         Returns:
             A new File instance pointing to the uploaded file
376
402
 
377
403
  # If remote_destination was not set by the user, and the configured raw data path is also local,
378
404
  # then let's optimize by not uploading.
405
+ hash_value = hash_method if isinstance(hash_method, str) else None
406
+ hash_method = hash_method if isinstance(hash_method, HashMethod) else None
379
407
  if "file" in protocol:
380
408
  if remote_destination is None:
381
409
  path = str(Path(local_path).absolute())
382
410
  else:
383
411
  # Otherwise, actually make a copy of the file
384
- async with aiofiles.open(remote_path, "rb") as src:
385
- async with aiofiles.open(local_path, "wb") as dst:
386
- await dst.write(await src.read())
412
+ async with aiofiles.open(local_path, "rb") as src:
413
+ async with aiofiles.open(remote_path, "wb") as dst:
414
+ if hash_method:
415
+ dst_wrapper = HashingWriter(dst, accumulator=hash_method)
416
+ await dst_wrapper.write(await src.read())
417
+ hash_value = dst_wrapper.result()
418
+ else:
419
+ await dst.write(await src.read())
387
420
  path = str(Path(remote_path).absolute())
388
421
  else:
389
422
  # Otherwise upload to remote using async storage layer
390
- path = await storage.put(str(local_path), remote_path)
423
+ if hash_method:
424
+ # We can skip the wrapper if the hash method is just a precomputed value
425
+ if not isinstance(hash_method, PrecomputedValue):
426
+ async with aiofiles.open(local_path, "rb") as src:
427
+ src_wrapper = AsyncHashingReader(src, accumulator=hash_method)
428
+ path = await storage.put_stream(src_wrapper, to_path=remote_path)
429
+ hash_value = src_wrapper.result()
430
+ else:
431
+ path = await storage.put(str(local_path), remote_path)
432
+ hash_value = hash_method.result()
433
+ else:
434
+ path = await storage.put(str(local_path), remote_path)
391
435
 
392
- f = cls(path=path, name=filename)
436
+ f = cls(path=path, name=filename, hash_method=hash_method, hash=hash_value)
393
437
  return f
394
438
 
395
439
 
@@ -432,7 +476,8 @@ class FileTransformer(TypeTransformer[File]):
                     ),
                     uri=python_val.path,
                 )
-            )
+            ),
+            hash=python_val.hash if python_val.hash else None,
         )
 
     async def to_python_value(
@@ -450,7 +495,8 @@ class FileTransformer(TypeTransformer[File]):
 
         uri = lv.scalar.blob.uri
         filename = Path(uri).name
-        f: File = File(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format)
+        hash_value = lv.hash if lv.hash else None
+        f: File = File(path=uri, name=filename, format=lv.scalar.blob.metadata.type.format, hash=hash_value)
         return f
 
     def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[File]:
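
`File` mirrors the `Dir` changes and goes further: a plain string passed as `hash_method` is treated as a precomputed cache key, while a `HashMethod` instance makes `open()` and `from_local` hash the bytes as they stream through `HashingWriter`/`AsyncHashingReader`. A sketch using only the string form, since the `HashMethod` constructor is not shown in this diff (assumes an initialized Flyte session, as `from_local` is decorated with `requires_initialization`):

```python
from flyte.io import File

# Attach a known cache key to an existing remote object; no bytes are read.
f = File.from_existing_remote("s3://bucket/model.pt", file_cache_key="sha256:abc123")

async def publish() -> File:
    # A string hash_method is used as the cache key as-is; no content hashing occurs.
    return await File.from_local("/tmp/model.pt", hash_method="sha256:abc123")
```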