langfun 0.1.2.dev202501150804__py3-none-any.whl → 0.1.2.dev202501170804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/core/__init__.py CHANGED
@@ -75,7 +75,6 @@ from langfun.core.sampling import random_sample
75
75
  from langfun.core.concurrent import RetryEntry
76
76
  from langfun.core.concurrent import concurrent_execute
77
77
  from langfun.core.concurrent import concurrent_map
78
- from langfun.core.concurrent import with_context_access
79
78
  from langfun.core.concurrent import with_retry
80
79
 
81
80
  # Interface for natural language formattable.
langfun/core/component.py CHANGED
@@ -13,10 +13,7 @@
13
13
  # limitations under the License.
14
14
  """langfun Component."""
15
15
 
16
- import contextlib
17
- import dataclasses
18
- import threading
19
- from typing import Annotated, Any, ContextManager, Iterator, Type
16
+ from typing import ContextManager
20
17
  import pyglove as pg
21
18
 
22
19
 
@@ -24,22 +21,9 @@ import pyglove as pg
24
21
  RAISE_IF_HAS_ERROR = (pg.MISSING_VALUE,)
25
22
 
26
23
 
27
- class Component(pg.Object):
24
+ class Component(pg.ContextualObject):
28
25
  """Base class for langfun components."""
29
26
 
30
- # Override __repr__ format to use inferred values when available.
31
- __repr_format_kwargs__ = dict(
32
- compact=True,
33
- use_inferred=True,
34
- )
35
-
36
- # Override __str__ format to use inferred values when available.
37
- __str_format_kwargs__ = dict(
38
- compact=False,
39
- verbose=False,
40
- use_inferred=True,
41
- )
42
-
43
27
  # Allow symbolic assignment, which invalidates the object and recomputes
44
28
  # states upon update.
45
29
  allow_symbolic_assignment = True
@@ -75,122 +59,23 @@ class Component(pg.Object):
75
59
  if additional_fields:
76
60
  cls.update_schema(additional_fields)
77
61
 
78
- def _on_bound(self):
79
- super()._on_bound()
80
- self._tls = threading.local()
81
-
82
- def _sym_inferred(self, key: str, **kwargs):
83
- """Override to allow attribute to access scoped value.
84
-
85
- Args:
86
- key: attribute name.
87
- **kwargs: Optional keyword arguments for value inference.
88
-
89
- Returns:
90
- The value of the symbolic attribute. If not available, returns the
91
- default value.
92
-
93
- Raises:
94
- AttributeError: If the attribute does not exist or contextual attribute
95
- is not ready.
96
- """
97
- if key not in self._sym_attributes:
98
- raise AttributeError(key)
99
-
100
- # Step 1: Try use value from `self.override`.
101
- # The reason is that `self.override` is short-lived and explicitly specified
102
- # by the user in scenarios like `LangFunc.render`, which should not be
103
- # affected by `lf.context`.
104
- v = _get_scoped_value(self._tls, _CONTEXT_OVERRIDES, key)
105
- if v is not None:
106
- return v.value
107
-
108
- # Step 2: Try use value from `lf.context` with `override_attrs`.
109
- # This gives users a chance to override the bound attributes of components
110
- # from the top, allowing change of bindings without modifying the code
111
- # that produces the components.
112
- override = get_contextual_override(key)
113
- if override and override.override_attrs:
114
- return override.value
115
-
116
- # Step 3: Try use value from the symbolic tree, starting from self to
117
- # the root of the tree.
118
- # Step 4: If the value is not present, use the value from `context()` (
119
- # override_attrs=False).
120
- # Step 5: Otherwise use the default value from `ContextualAttribute`.
121
- return super()._sym_inferred(key, context_override=override, **kwargs)
122
-
123
- def override(
124
- self, **kwargs) -> ContextManager[dict[str, 'ContextualOverride']]:
125
- """Context manager to override the attributes of this component."""
126
- vs = {k: ContextualOverride(v) for k, v in kwargs.items()}
127
- return _contextual_scope(self._tls, _CONTEXT_OVERRIDES, **vs)
128
-
129
- def __getattribute__(self, name: str) -> Any:
130
- """Override __getattribute__ to deal with class attribute override."""
131
- if not name.startswith('_') and hasattr(self.__class__, name):
132
- tls = self.__dict__.get('_tls', None)
133
- if tls is not None:
134
- v = _get_scoped_value(tls, _CONTEXT_OVERRIDES, name)
135
- if v is not None:
136
- return v.value
137
- return super().__getattribute__(name)
138
-
139
-
140
- _global_tls = threading.local()
141
-
142
- _CONTEXT_OVERRIDES = 'context_overrides'
143
-
144
62
 
145
- @dataclasses.dataclass(frozen=True)
146
- class ContextualOverride:
147
- """Value marker for contextual override for an attribute."""
63
+ # Aliases from PyGlove for ease of access.
64
+ context = pg.contextual_override
65
+ get_contextual_override = pg.utils.get_contextual_override
66
+ context_value = pg.utils.contextual_value
67
+ all_contextual_values = pg.utils.all_contextual_values
68
+ contextual = pg.contextual_attribute
148
69
 
149
- # Overridden value.
150
- value: Any
151
-
152
- # If True, this override will apply to both current scope and nested scope,
153
- # meaning current `lf.context` will take precedence over all nested
154
- # `lf.context` on this attribute.
155
- cascade: bool = False
156
-
157
- # If True, this override will apply to attributes that already have values.
158
- override_attrs: bool = False
159
-
160
-
161
- def context(
162
- *,
163
- cascade: bool = False,
164
- override_attrs: bool = False,
165
- **variables,
166
- ) -> ContextManager[dict[str, ContextualOverride]]:
167
- """Context manager to provide overriden values for contextual attributes.
168
-
169
- Args:
170
- cascade: If True, this override will apply to both current scope and nested
171
- scope, meaning that this `lf.context` will take precedence over all
172
- nested `lf.context` on the overriden variables.
173
- override_attrs: If True, this override will apply to attributes that already
174
- have values. Otherwise overridden variables will only be used for
175
- contextual attributes whose values are not present.
176
- **variables: Key/values as override for contextual attributes.
177
-
178
- Returns:
179
- A dict of attribute names to their contextual overrides.
180
- """
181
- vs = {}
182
- for k, v in variables.items():
183
- if not isinstance(v, ContextualOverride):
184
- v = ContextualOverride(v, cascade, override_attrs)
185
- vs[k] = v
186
- return _contextual_scope(_global_tls, _CONTEXT_OVERRIDES, **vs)
70
+ # Decorator for setting the positional arguments for Component.
71
+ use_init_args = pg.use_init_args
187
72
 
188
73
 
189
74
  def use_settings(
190
75
  *,
191
76
  cascade: bool = False,
192
77
  **settings,
193
- ) -> ContextManager[dict[str, ContextualOverride]]:
78
+ ) -> ContextManager[dict[str, pg.utils.ContextualOverride]]:
194
79
  """Shortcut method for overriding component attributes.
195
80
 
196
81
  Args:
@@ -203,150 +88,3 @@ def use_settings(
203
88
  A dict of attribute names to their contextual overrides.
204
89
  """
205
90
  return context(cascade=cascade, override_attrs=True, **settings)
206
-
207
-
208
- def get_contextual_override(var_name: str) -> ContextualOverride | None:
209
- """Returns the overriden contextual value in current scope."""
210
- return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
211
-
212
-
213
- def context_value(var_name: str, default: Any = RAISE_IF_HAS_ERROR) -> Any:
214
- """Returns the value of a variable defined in `lf.context`."""
215
- override = get_contextual_override(var_name)
216
- if override is None:
217
- if default == RAISE_IF_HAS_ERROR:
218
- raise KeyError(f'{var_name!r} does not exist in current context.')
219
- return default
220
- return override.value
221
-
222
-
223
- def all_contextual_values() -> dict[str, Any]:
224
- """Returns all contextual values provided from `lf.context` in scope."""
225
- overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
226
- return {k: v.value for k, v in overrides.items()}
227
-
228
-
229
- @contextlib.contextmanager
230
- def _contextual_scope(
231
- tls: threading.local, tls_key, **variables
232
- ) -> Iterator[dict[str, ContextualOverride]]:
233
- """Context manager to set variables within a scope."""
234
- previous_values = getattr(tls, tls_key, {})
235
- current_values = dict(previous_values)
236
- for k, v in variables.items():
237
- old_v = current_values.get(k, None)
238
- if old_v and old_v.cascade:
239
- v = old_v
240
- current_values[k] = v
241
- try:
242
- setattr(tls, tls_key, current_values)
243
- yield current_values
244
- finally:
245
- setattr(tls, tls_key, previous_values)
246
-
247
-
248
- def _get_scoped_value(
249
- tls: threading.local, tls_key: str, var_name: str, default: Any = None
250
- ) -> ContextualOverride:
251
- """Gets the value for requested variable from current scope."""
252
- scoped_values = getattr(tls, tls_key, {})
253
- return scoped_values.get(var_name, default)
254
-
255
-
256
- class ContextualAttribute(
257
- pg.symbolic.ValueFromParentChain, pg.views.HtmlTreeView.Extension
258
- ):
259
- """Attributes whose values are inferred from the context of the component.
260
-
261
- Please see go/langfun-component#attribute-value-retrieval for details.
262
- """
263
-
264
- NO_DEFAULT = (pg.MISSING_VALUE,)
265
-
266
- type: Annotated[Type[Any] | None, 'An optional type constraint.'] = None
267
-
268
- default: Any = NO_DEFAULT
269
-
270
- def value_from(
271
- self,
272
- parent,
273
- *,
274
- context_override: ContextualOverride | None = None,
275
- **kwargs,
276
- ):
277
- if parent not in (None, self.sym_parent) and isinstance(parent, Component):
278
- # Apply original search logic along the component containing chain.
279
- return super().value_from(parent, **kwargs)
280
- elif parent is None:
281
- # When there is no value inferred from the symbolic tree.
282
- # Search context override, and then attribute-level default.
283
- if context_override:
284
- return context_override.value
285
- if self.default == ContextualAttribute.NO_DEFAULT:
286
- return pg.MISSING_VALUE
287
- return self.default
288
- else:
289
- return pg.MISSING_VALUE
290
-
291
- def _html_tree_view_content(
292
- self,
293
- *,
294
- view: pg.views.HtmlTreeView,
295
- parent: Any = None,
296
- root_path: pg.KeyPath | None = None,
297
- **kwargs,
298
- ) -> pg.Html:
299
- inferred_value = pg.MISSING_VALUE
300
- if isinstance(parent, pg.Symbolic) and root_path:
301
- inferred_value = parent.sym_inferred(root_path.key, pg.MISSING_VALUE)
302
-
303
- if inferred_value is not pg.MISSING_VALUE:
304
- kwargs.pop('name', None)
305
- return view.render(
306
- inferred_value, parent=self,
307
- root_path=pg.KeyPath('<inferred>', root_path),
308
- **view.get_passthrough_kwargs(**kwargs)
309
- )
310
- return pg.Html.element(
311
- 'div',
312
- [
313
- '(not available)',
314
- ],
315
- css_classes=['unavailable-contextual'],
316
- )
317
-
318
- def _html_tree_view_config(self) -> dict[str, Any]:
319
- return pg.views.HtmlTreeView.get_kwargs(
320
- super()._html_tree_view_config(),
321
- dict(
322
- collapse_level=1,
323
- )
324
- )
325
-
326
- @classmethod
327
- def _html_tree_view_css_styles(cls) -> list[str]:
328
- return super()._html_tree_view_css_styles() + [
329
- """
330
- .contextual-attribute {
331
- color: purple;
332
- }
333
- .unavailable-contextual {
334
- color: gray;
335
- font-style: italic;
336
- }
337
- """
338
- ]
339
-
340
-
341
- # NOTE(daiyip): Returning Any instead of `lf.ContextualAttribute` to avoid
342
- # pytype check error as `contextual()` can be assigned to any type.
343
- def contextual(
344
- type: Type[Any] | None = None, # pylint: disable=redefined-builtin
345
- default: Any = ContextualAttribute.NO_DEFAULT,
346
- ) -> Any:
347
- """Value marker for a contextual attribute."""
348
- return ContextualAttribute(type=type, default=default, allow_partial=True)
349
-
350
-
351
- # Decorator for setting the positional arguments for Component.
352
- use_init_args = pg.use_init_args
@@ -73,26 +73,7 @@ class ComponentContextTest(unittest.TestCase):
73
73
  self.assertEqual(a1.y, 2)
74
74
  self.assertEqual(a1.z, -1)
75
75
 
76
- with lf.context(x=3, y=3, z=3) as parent_override:
77
- self.assertEqual(
78
- parent_override,
79
- dict(
80
- x=lf.ContextualOverride(3, cascade=False, override_attrs=False),
81
- y=lf.ContextualOverride(3, cascade=False, override_attrs=False),
82
- z=lf.ContextualOverride(3, cascade=False, override_attrs=False),
83
- ),
84
- )
85
- self.assertEqual(
86
- lf.get_contextual_override('y'),
87
- lf.ContextualOverride(3, cascade=False, override_attrs=False),
88
- )
89
- self.assertEqual(lf.context_value('x'), 3)
90
- self.assertIsNone(lf.context_value('f', None))
91
- with self.assertRaisesRegex(KeyError, '.* does not exist'):
92
- lf.context_value('f')
93
-
94
- self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
95
-
76
+ with lf.context(x=3, y=3, z=3):
96
77
  # Member attributes take precedence over `lf.context`.
97
78
  self.assertEqual(a1.x, 1)
98
79
  self.assertEqual(a1.y, 2)
@@ -109,15 +90,7 @@ class ComponentContextTest(unittest.TestCase):
109
90
  self.assertEqual(a1.z, 3)
110
91
 
111
92
  # Test nested contextual override with override_attrs=True (default).
112
- with lf.context(y=4, z=4, override_attrs=True) as nested_override:
113
- self.assertEqual(
114
- nested_override,
115
- dict(
116
- x=lf.ContextualOverride(3, cascade=False, override_attrs=False),
117
- y=lf.ContextualOverride(4, cascade=False, override_attrs=True),
118
- z=lf.ContextualOverride(4, cascade=False, override_attrs=True),
119
- ),
120
- )
93
+ with lf.context(y=4, z=4, override_attrs=True):
121
94
 
122
95
  # Member attribute is not overriden as current scope does not override
123
96
  # `x``.
@@ -25,7 +25,6 @@ import threading
25
25
  import time
26
26
  from typing import Annotated, Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union
27
27
 
28
- from langfun.core import component
29
28
  import pyglove as pg
30
29
 
31
30
 
@@ -39,18 +38,6 @@ except ImportError:
39
38
  tqdm = None
40
39
 
41
40
 
42
- def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
43
- """Derives a user function with the access to the current context."""
44
- with component.context() as current_context:
45
- pass
46
-
47
- def _func(*args, **kwargs) -> Any:
48
- with component.context(**current_context):
49
- return func(*args, **kwargs)
50
-
51
- return _func
52
-
53
-
54
41
  class RetryError(RuntimeError):
55
42
  """Retry error."""
56
43
 
@@ -249,7 +236,8 @@ def concurrent_execute(
249
236
  try:
250
237
  executed_jobs = list(
251
238
  executor.map(
252
- lambda job: job(), [with_context_access(job) for job in jobs]
239
+ lambda job: job(),
240
+ [pg.with_contextual_override(job) for job in jobs]
253
241
  )
254
242
  )
255
243
  for job in executed_jobs:
@@ -736,9 +724,7 @@ def concurrent_map(
736
724
  retry_interval=retry_interval,
737
725
  exponential_backoff=exponential_backoff,
738
726
  )
739
- future = executor.submit(
740
- with_context_access(job),
741
- )
727
+ future = executor.submit(pg.with_contextual_override(job))
742
728
  pending_futures.append(future)
743
729
  future_to_job[future] = job
744
730
  total += 1
@@ -29,22 +29,6 @@ class A(component.Component):
29
29
  y: int = component.contextual()
30
30
 
31
31
 
32
- class WithContextAccessTest(unittest.TestCase):
33
-
34
- def test_context_access(self):
35
- inputs = [A(1), A(2)]
36
- with futures.ThreadPoolExecutor() as executor:
37
- with component.context(y=3):
38
- self.assertEqual(
39
- list(
40
- executor.map(
41
- concurrent.with_context_access(lambda x: x.y), inputs
42
- )
43
- ),
44
- [3, 3],
45
- )
46
-
47
-
48
32
  class RetryErrorTest(unittest.TestCase):
49
33
 
50
34
  def test_basics(self):
@@ -30,6 +30,9 @@ from langfun.core.llms.compositional import RandomChoice
30
30
  # REST-based models.
31
31
  from langfun.core.llms.rest import REST
32
32
 
33
+ # VertexAI-based models.
34
+ from langfun.core.llms.vertexai import VertexAI
35
+
33
36
  # Gemini models.
34
37
  from langfun.core.llms.google_genai import GenAI
35
38
  from langfun.core.llms.google_genai import GeminiFlash2_0ThinkingExp_20241219
@@ -44,7 +47,7 @@ from langfun.core.llms.google_genai import GeminiFlash1_5_002
44
47
  from langfun.core.llms.google_genai import GeminiFlash1_5_001
45
48
  from langfun.core.llms.google_genai import GeminiPro1
46
49
 
47
- from langfun.core.llms.vertexai import VertexAI
50
+ from langfun.core.llms.vertexai import VertexAIGemini
48
51
  from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0ThinkingExp_20241219
49
52
  from langfun.core.llms.vertexai import VertexAIGeminiFlash2_0Exp
50
53
  from langfun.core.llms.vertexai import VertexAIGeminiExp_20241206
@@ -114,6 +117,8 @@ from langfun.core.llms.openai import Gpt3Curie
114
117
  from langfun.core.llms.openai import Gpt3Babbage
115
118
  from langfun.core.llms.openai import Gpt3Ada
116
119
 
120
+ # Anthropic models.
121
+
117
122
  from langfun.core.llms.anthropic import Anthropic
118
123
  from langfun.core.llms.anthropic import Claude35Sonnet
119
124
  from langfun.core.llms.anthropic import Claude35Sonnet20241022
@@ -121,11 +126,14 @@ from langfun.core.llms.anthropic import Claude35Sonnet20240620
121
126
  from langfun.core.llms.anthropic import Claude3Opus
122
127
  from langfun.core.llms.anthropic import Claude3Sonnet
123
128
  from langfun.core.llms.anthropic import Claude3Haiku
124
- from langfun.core.llms.anthropic import VertexAIAnthropic
125
- from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20241022
126
- from langfun.core.llms.anthropic import VertexAIClaude3_5_Sonnet_20240620
127
- from langfun.core.llms.anthropic import VertexAIClaude3_5_Haiku_20241022
128
- from langfun.core.llms.anthropic import VertexAIClaude3_Opus_20240229
129
+
130
+ from langfun.core.llms.vertexai import VertexAIAnthropic
131
+ from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20241022
132
+ from langfun.core.llms.vertexai import VertexAIClaude3_5_Sonnet_20240620
133
+ from langfun.core.llms.vertexai import VertexAIClaude3_5_Haiku_20241022
134
+ from langfun.core.llms.vertexai import VertexAIClaude3_Opus_20240229
135
+
136
+ # Misc open source models.
129
137
 
130
138
  from langfun.core.llms.groq import Groq
131
139
  from langfun.core.llms.groq import GroqLlama3_2_3B
@@ -14,9 +14,8 @@
14
14
  """Language models from Anthropic."""
15
15
 
16
16
  import base64
17
- import functools
18
17
  import os
19
- from typing import Annotated, Any, Literal
18
+ from typing import Annotated, Any
20
19
 
21
20
  import langfun.core as lf
22
21
  from langfun.core import modalities as lf_modalities
@@ -24,20 +23,6 @@ from langfun.core.llms import rest
24
23
  import pyglove as pg
25
24
 
26
25
 
27
- try:
28
- # pylint: disable=g-import-not-at-top
29
- from google import auth as google_auth
30
- from google.auth import credentials as credentials_lib
31
- from google.auth.transport import requests as auth_requests
32
- Credentials = credentials_lib.Credentials
33
- # pylint: enable=g-import-not-at-top
34
- except ImportError:
35
- google_auth = None
36
- auth_requests = None
37
- credentials_lib = None
38
- Credentials = Any # pylint: disable=invalid-name
39
-
40
-
41
26
  SUPPORTED_MODELS_AND_SETTINGS = {
42
27
  # See https://docs.anthropic.com/claude/docs/models-overview
43
28
  # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
@@ -379,110 +364,3 @@ class Claude21(Anthropic):
379
364
  class ClaudeInstant(Anthropic):
380
365
  """Cheapest small and fast model, 100K context window."""
381
366
  model = 'claude-instant-1.2'
382
-
383
-
384
- #
385
- # Authropic models on VertexAI.
386
- #
387
-
388
-
389
- class VertexAIAnthropic(Anthropic):
390
- """Anthropic models on VertexAI."""
391
-
392
- project: Annotated[
393
- str | None,
394
- 'Google Cloud project ID.',
395
- ] = None
396
-
397
- location: Annotated[
398
- Literal['us-east5', 'europe-west1'],
399
- 'GCP location with Anthropic models hosted.'
400
- ] = 'us-east5'
401
-
402
- credentials: Annotated[
403
- Credentials | None, # pytype: disable=invalid-annotation
404
- (
405
- 'Credentials to use. If None, the default credentials '
406
- 'to the environment will be used.'
407
- ),
408
- ] = None
409
-
410
- api_version = 'vertex-2023-10-16'
411
-
412
- def _on_bound(self):
413
- super()._on_bound()
414
- if google_auth is None:
415
- raise ValueError(
416
- 'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
417
- )
418
- self._project = None
419
- self._credentials = None
420
-
421
- def _initialize(self):
422
- project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
423
- if not project:
424
- raise ValueError(
425
- 'Please specify `project` during `__init__` or set environment '
426
- 'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
427
- )
428
- self._project = project
429
- credentials = self.credentials
430
- if credentials is None:
431
- # Use default credentials.
432
- credentials = google_auth.default(
433
- scopes=['https://www.googleapis.com/auth/cloud-platform']
434
- )
435
- self._credentials = credentials
436
-
437
- @functools.cached_property
438
- def _session(self):
439
- assert self._api_initialized
440
- assert self._credentials is not None
441
- assert auth_requests is not None
442
- s = auth_requests.AuthorizedSession(self._credentials)
443
- s.headers.update(self.headers or {})
444
- return s
445
-
446
- @property
447
- def headers(self):
448
- return {
449
- 'Content-Type': 'application/json; charset=utf-8',
450
- }
451
-
452
- @property
453
- def api_endpoint(self) -> str:
454
- return (
455
- f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
456
- f'{self._project}/locations/{self.location}/publishers/anthropic/'
457
- f'models/{self.model}:streamRawPredict'
458
- )
459
-
460
- def request(
461
- self,
462
- prompt: lf.Message,
463
- sampling_options: lf.LMSamplingOptions
464
- ):
465
- request = super().request(prompt, sampling_options)
466
- request['anthropic_version'] = self.api_version
467
- del request['model']
468
- return request
469
-
470
-
471
- class VertexAIClaude3_Opus_20240229(VertexAIAnthropic): # pylint: disable=invalid-name
472
- """Anthropic's Claude 3 Opus model on VertexAI."""
473
- model = 'claude-3-opus@20240229'
474
-
475
-
476
- class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic): # pylint: disable=invalid-name
477
- """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
478
- model = 'claude-3-5-sonnet-v2@20241022'
479
-
480
-
481
- class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic): # pylint: disable=invalid-name
482
- """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
483
- model = 'claude-3-5-sonnet@20240620'
484
-
485
-
486
- class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic): # pylint: disable=invalid-name
487
- """Anthropic's Claude 3.5 Haiku model on VertexAI."""
488
- model = 'claude-3-5-haiku@20241022'
@@ -19,9 +19,6 @@ from typing import Any
19
19
  import unittest
20
20
  from unittest import mock
21
21
 
22
- from google.auth import exceptions
23
- from langfun.core import language_model
24
- from langfun.core import message as lf_message
25
22
  from langfun.core import modalities as lf_modalities
26
23
  from langfun.core.llms import anthropic
27
24
  import pyglove as pg
@@ -186,50 +183,5 @@ class AnthropicTest(unittest.TestCase):
186
183
  lm('hello', max_attempts=1)
187
184
 
188
185
 
189
- class VertexAIAnthropicTest(unittest.TestCase):
190
- """Tests for VertexAI Anthropic models."""
191
-
192
- def test_basics(self):
193
- with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
194
- lm = anthropic.VertexAIClaude3_5_Sonnet_20241022()
195
- lm('hi')
196
-
197
- model = anthropic.VertexAIClaude3_5_Sonnet_20241022(project='langfun')
198
-
199
- # NOTE(daiyip): For OSS users, default credentials are not available unless
200
- # users have already set up their GCP project. Therefore we ignore the
201
- # exception here.
202
- try:
203
- model._initialize()
204
- except exceptions.DefaultCredentialsError:
205
- pass
206
-
207
- self.assertEqual(
208
- model.api_endpoint,
209
- (
210
- 'https://us-east5-aiplatform.googleapis.com/v1/projects/'
211
- 'langfun/locations/us-east5/publishers/anthropic/'
212
- 'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
213
- )
214
- )
215
- request = model.request(
216
- lf_message.UserMessage('hi'),
217
- language_model.LMSamplingOptions(temperature=0.0),
218
- )
219
- self.assertEqual(
220
- request,
221
- {
222
- 'anthropic_version': 'vertex-2023-10-16',
223
- 'max_tokens': 8192,
224
- 'messages': [
225
- {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
226
- ],
227
- 'stream': False,
228
- 'temperature': 0.0,
229
- 'top_k': 40,
230
- },
231
- )
232
-
233
-
234
186
  if __name__ == '__main__':
235
187
  unittest.main()
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The Langfun Authors
1
+ # Copyright 2025 The Langfun Authors
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,10 +15,12 @@
15
15
 
16
16
  import functools
17
17
  import os
18
- from typing import Annotated, Any
18
+ from typing import Annotated, Any, Literal
19
19
 
20
20
  import langfun.core as lf
21
+ from langfun.core.llms import anthropic
21
22
  from langfun.core.llms import gemini
23
+ from langfun.core.llms import rest
22
24
  import pyglove as pg
23
25
 
24
26
  try:
@@ -36,10 +38,21 @@ except ImportError:
36
38
  Credentials = Any
37
39
 
38
40
 
39
- @lf.use_init_args(['model'])
40
- @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
41
- class VertexAI(gemini.Gemini):
42
- """Language model served on VertexAI with REST API."""
41
+ @pg.use_init_args(['api_endpoint'])
42
+ class VertexAI(rest.REST):
43
+ """Base class for VertexAI models.
44
+
45
+ This class handles the authentication of vertex AI models. Subclasses
46
+ should implement `request` and `result` methods, as well as the `api_endpoint`
47
+ property. Or let users to provide them as __init__ arguments.
48
+
49
+ Please check out VertexAIGemini in `gemini.py` as an example.
50
+ """
51
+
52
+ model: Annotated[
53
+ str | None,
54
+ 'Model ID.'
55
+ ] = None
43
56
 
44
57
  project: Annotated[
45
58
  str | None,
@@ -114,6 +127,17 @@ class VertexAI(gemini.Gemini):
114
127
  s.headers.update(self.headers or {})
115
128
  return s
116
129
 
130
+
131
+ #
132
+ # Gemini models served by Vertex AI.
133
+ #
134
+
135
+
136
+ @pg.use_init_args(['model'])
137
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
138
+ class VertexAIGemini(VertexAI, gemini.Gemini):
139
+ """Gemini models served by Vertex AI.."""
140
+
117
141
  @property
118
142
  def api_endpoint(self) -> str:
119
143
  assert self._api_initialized
@@ -124,7 +148,7 @@ class VertexAI(gemini.Gemini):
124
148
  )
125
149
 
126
150
 
127
- class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=invalid-name
151
+ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAIGemini): # pylint: disable=invalid-name
128
152
  """Vertex AI Gemini Flash 2.0 Thinking model launched on 12/19/2024."""
129
153
 
130
154
  api_version = 'v1alpha'
@@ -132,61 +156,128 @@ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI): # pylint: disable=i
132
156
  timeout = None
133
157
 
134
158
 
135
- class VertexAIGeminiFlash2_0Exp(VertexAI): # pylint: disable=invalid-name
159
+ class VertexAIGeminiFlash2_0Exp(VertexAIGemini): # pylint: disable=invalid-name
136
160
  """Vertex AI Gemini 2.0 Flash model."""
137
161
 
138
162
  model = 'gemini-2.0-flash-exp'
139
163
 
140
164
 
141
- class VertexAIGeminiExp_20241206(VertexAI): # pylint: disable=invalid-name
165
+ class VertexAIGeminiExp_20241206(VertexAIGemini): # pylint: disable=invalid-name
142
166
  """Vertex AI Gemini Experimental model launched on 12/06/2024."""
143
167
 
144
168
  model = 'gemini-exp-1206'
145
169
 
146
170
 
147
- class VertexAIGeminiExp_20241114(VertexAI): # pylint: disable=invalid-name
171
+ class VertexAIGeminiExp_20241114(VertexAIGemini): # pylint: disable=invalid-name
148
172
  """Vertex AI Gemini Experimental model launched on 11/14/2024."""
149
173
 
150
174
  model = 'gemini-exp-1114'
151
175
 
152
176
 
153
- class VertexAIGeminiPro1_5(VertexAI): # pylint: disable=invalid-name
177
+ class VertexAIGeminiPro1_5(VertexAIGemini): # pylint: disable=invalid-name
154
178
  """Vertex AI Gemini 1.5 Pro model."""
155
179
 
156
180
  model = 'gemini-1.5-pro-latest'
157
181
 
158
182
 
159
- class VertexAIGeminiPro1_5_002(VertexAI): # pylint: disable=invalid-name
183
+ class VertexAIGeminiPro1_5_002(VertexAIGemini): # pylint: disable=invalid-name
160
184
  """Vertex AI Gemini 1.5 Pro model."""
161
185
 
162
186
  model = 'gemini-1.5-pro-002'
163
187
 
164
188
 
165
- class VertexAIGeminiPro1_5_001(VertexAI): # pylint: disable=invalid-name
189
+ class VertexAIGeminiPro1_5_001(VertexAIGemini): # pylint: disable=invalid-name
166
190
  """Vertex AI Gemini 1.5 Pro model."""
167
191
 
168
192
  model = 'gemini-1.5-pro-001'
169
193
 
170
194
 
171
- class VertexAIGeminiFlash1_5(VertexAI): # pylint: disable=invalid-name
195
+ class VertexAIGeminiFlash1_5(VertexAIGemini): # pylint: disable=invalid-name
172
196
  """Vertex AI Gemini 1.5 Flash model."""
173
197
 
174
198
  model = 'gemini-1.5-flash'
175
199
 
176
200
 
177
- class VertexAIGeminiFlash1_5_002(VertexAI): # pylint: disable=invalid-name
201
+ class VertexAIGeminiFlash1_5_002(VertexAIGemini): # pylint: disable=invalid-name
178
202
  """Vertex AI Gemini 1.5 Flash model."""
179
203
 
180
204
  model = 'gemini-1.5-flash-002'
181
205
 
182
206
 
183
- class VertexAIGeminiFlash1_5_001(VertexAI): # pylint: disable=invalid-name
207
+ class VertexAIGeminiFlash1_5_001(VertexAIGemini): # pylint: disable=invalid-name
184
208
  """Vertex AI Gemini 1.5 Flash model."""
185
209
 
186
210
  model = 'gemini-1.5-flash-001'
187
211
 
188
212
 
189
- class VertexAIGeminiPro1(VertexAI): # pylint: disable=invalid-name
213
+ class VertexAIGeminiPro1(VertexAIGemini): # pylint: disable=invalid-name
190
214
  """Vertex AI Gemini 1.0 Pro model."""
191
215
 
192
216
  model = 'gemini-1.0-pro'
217
+
218
+
219
+ #
220
+ # Anthropic models on Vertex AI.
221
+ #
222
+
223
+
224
+ @pg.use_init_args(['model'])
225
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
226
+ class VertexAIAnthropic(VertexAI, anthropic.Anthropic):
227
+ """Anthropic models on VertexAI."""
228
+
229
+ location: Annotated[
230
+ Literal['us-east5', 'europe-west1'],
231
+ 'GCP location with Anthropic models hosted.'
232
+ ] = 'us-east5'
233
+
234
+ api_version = 'vertex-2023-10-16'
235
+
236
+ @property
237
+ def headers(self):
238
+ return {
239
+ 'Content-Type': 'application/json; charset=utf-8',
240
+ }
241
+
242
+ @property
243
+ def api_endpoint(self) -> str:
244
+ return (
245
+ f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
246
+ f'{self._project}/locations/{self.location}/publishers/anthropic/'
247
+ f'models/{self.model}:streamRawPredict'
248
+ )
249
+
250
+ def request(
251
+ self,
252
+ prompt: lf.Message,
253
+ sampling_options: lf.LMSamplingOptions
254
+ ):
255
+ request = super().request(prompt, sampling_options)
256
+ request['anthropic_version'] = self.api_version
257
+ del request['model']
258
+ return request
259
+
260
+
261
+ # pylint: disable=invalid-name
262
+
263
+
264
+ class VertexAIClaude3_Opus_20240229(VertexAIAnthropic):
265
+ """Anthropic's Claude 3 Opus model on VertexAI."""
266
+ model = 'claude-3-opus@20240229'
267
+
268
+
269
+ class VertexAIClaude3_5_Sonnet_20241022(VertexAIAnthropic):
270
+ """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
271
+ model = 'claude-3-5-sonnet-v2@20241022'
272
+
273
+
274
+ class VertexAIClaude3_5_Sonnet_20240620(VertexAIAnthropic):
275
+ """Anthropic's Claude 3.5 Sonnet model on VertexAI."""
276
+ model = 'claude-3-5-sonnet@20240620'
277
+
278
+
279
+ class VertexAIClaude3_5_Haiku_20241022(VertexAIAnthropic):
280
+ """Anthropic's Claude 3.5 Haiku model on VertexAI."""
281
+ model = 'claude-3-5-haiku@20241022'
282
+
283
+ # pylint: enable=invalid-name
@@ -17,6 +17,8 @@ import os
17
17
  import unittest
18
18
  from unittest import mock
19
19
 
20
+ from google.auth import exceptions
21
+ import langfun.core as lf
20
22
  from langfun.core.llms import vertexai
21
23
 
22
24
 
@@ -48,5 +50,55 @@ class VertexAITest(unittest.TestCase):
48
50
  del os.environ['VERTEXAI_LOCATION']
49
51
 
50
52
 
53
class VertexAIAnthropicTest(unittest.TestCase):
  """Exercises the Vertex AI-hosted Anthropic (Claude) models."""

  def test_basics(self):
    # Calling a model that has no `project` configured must fail.
    with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
      unconfigured = vertexai.VertexAIClaude3_5_Sonnet_20241022()
      unconfigured('hi')

    lm = vertexai.VertexAIClaude3_5_Sonnet_20241022(project='langfun')

    # NOTE(daiyip): OSS users do not have default credentials unless their
    # GCP project is already set up, so a missing-credentials error is
    # tolerated here.
    try:
      lm._initialize()
    except exceptions.DefaultCredentialsError:
      pass

    expected_endpoint = (
        'https://us-east5-aiplatform.googleapis.com/v1/projects/'
        'langfun/locations/us-east5/publishers/anthropic/'
        'models/claude-3-5-sonnet-v2@20241022:streamRawPredict'
    )
    self.assertEqual(lm.api_endpoint, expected_endpoint)

    expected_headers = {'Content-Type': 'application/json; charset=utf-8'}
    self.assertEqual(lm.headers, expected_headers)

    expected_request = {
        'anthropic_version': 'vertex-2023-10-16',
        'max_tokens': 8192,
        'messages': [
            {'content': [{'text': 'hi', 'type': 'text'}], 'role': 'user'}
        ],
        'stream': False,
        'temperature': 0.0,
        'top_k': 40,
    }
    actual_request = lm.request(
        lf.UserMessage('hi'), lf.LMSamplingOptions(temperature=0.0),
    )
    self.assertEqual(actual_request, expected_request)
101
+
102
+
51
103
  if __name__ == '__main__':
52
104
  unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: langfun
3
- Version: 0.1.2.dev202501150804
3
+ Version: 0.1.2.dev202501170804
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -1,9 +1,9 @@
1
1
  langfun/__init__.py,sha256=fhfPXpHN7GoGqixpFfqhQkYxFs_siP_LhbjZhd3lhio,2497
2
- langfun/core/__init__.py,sha256=oo81OUA4Xy2q7icxcxAMicVzbla5gMDf9Nzw_iIAkis,4627
3
- langfun/core/component.py,sha256=HVrEoTL1Y01iqOHC3FYdbAOnffqfHHtGJXoK1vkdEwo,11583
4
- langfun/core/component_test.py,sha256=sG-T2wpvBfHqWGZE7sc4NayJj2aj5QFBzSwFiwrGEIc,10376
5
- langfun/core/concurrent.py,sha256=4CqOlH9YrtrK-Jijp09hLX7AYxRYL8slpy-tL2aqnIM,33121
6
- langfun/core/concurrent_test.py,sha256=0ucEYxPgTErOOK1TIN-TSqZphHS1Vl1VssyQP49cpOI,17997
2
+ langfun/core/__init__.py,sha256=Cy3BU8R9VHnJmbxiN9QXXBb1K09ZoESYIJGf_uwJrDs,4571
3
+ langfun/core/component.py,sha256=g1kQM0bryYYYWVDrSMnHfc74wIBbpfe5_B3s-UIP5GE,3028
4
+ langfun/core/component_test.py,sha256=0CxTgjAud3aj8wBauFhG2FHDqrxCTl4OI4gzQTad-40,9254
5
+ langfun/core/concurrent.py,sha256=zY-pXqlGqss_GI20tM1gXvyW8QepVPUuFNmutcIdhbI,32760
6
+ langfun/core/concurrent_test.py,sha256=rc5T-2giWgtbwNuN6gmei7Uwo66HsJeeRtXZCpya_QU,17590
7
7
  langfun/core/console.py,sha256=V_mOiFi9oGh8gLsUeR56pdFDkuvYOpvQt7DY1KUTWTA,2535
8
8
  langfun/core/console_test.py,sha256=pBOcuNMJdVELywvroptfcRtJMsegMm3wSlHAL2TdxVk,1679
9
9
  langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
@@ -71,9 +71,9 @@ langfun/core/eval/v2/reporting.py,sha256=QOp5jX761Esvi5w_UIRLDqPY_XRO6ru02-DOrdq
71
71
  langfun/core/eval/v2/reporting_test.py,sha256=UmYSAQvD3AIXsSyWQ-WD2uLtEISYpmBeoKY5u5Qwc8E,5696
72
72
  langfun/core/eval/v2/runners.py,sha256=DKEmSlGXjOXKWFdBhTpLy7tMsBHZHd1Brl3hWIngsSQ,15931
73
73
  langfun/core/eval/v2/runners_test.py,sha256=A37fKK2MvAVTiShsg_laluJzJ9AuAQn52k7HPbfD0Ks,11666
74
- langfun/core/llms/__init__.py,sha256=Ntr0kvHc17VEZ5EV9fCoYY1kzRvQxCoZrtDRYNiMWCs,6742
75
- langfun/core/llms/anthropic.py,sha256=a5MmnFsBA0CbfvwzXT1v_0fqLRMrhUNdh1tx6469PQ4,14357
76
- langfun/core/llms/anthropic_test.py,sha256=-2U4kc_pgBM7wqxu8RuxzyHPGww1EAWqKUvN4PW8Btw,8058
74
+ langfun/core/llms/__init__.py,sha256=wA6t_E3peTYTjsW6uOHnOs9wjQ_Tj1WYlhVVLk2Sjcg,6867
75
+ langfun/core/llms/anthropic.py,sha256=z_DWDpR1VKNzv6wq-9CXLzWdqCDXRKuVFacJNpgBqAs,10826
76
+ langfun/core/llms/anthropic_test.py,sha256=zZ2eSP8hhVv-RDSWxT7wX-NS5DfGfQmCjS9P0pusAHM,6556
77
77
  langfun/core/llms/compositional.py,sha256=csW_FLlgL-tpeyCOTVvfUQkMa_zCN5Y2I-YbSNuK27U,2872
78
78
  langfun/core/llms/compositional_test.py,sha256=4eTnOer-DncRKGaIJW2ZQQMLnt5r2R0UIx_DYOvGAQo,2027
79
79
  langfun/core/llms/deepseek.py,sha256=Y7DlLUWrukbPVyBMesppd-m75Q-PxD0b3KnMKaoY_8I,3744
@@ -94,8 +94,8 @@ langfun/core/llms/openai_compatible_test.py,sha256=0uFYhCiuHo2Wrlgj16-GRG6rW8P6E
94
94
  langfun/core/llms/openai_test.py,sha256=m85YjGCvWvV5ZYagjC0FqI0FcqyCEVCbUUs8Wm3iUrc,2475
95
95
  langfun/core/llms/rest.py,sha256=sWbYUV8S3SuOg9giq7xwD-xDRfaF7NP_ig7bI52-Rj4,3442
96
96
  langfun/core/llms/rest_test.py,sha256=zWGiI08f9gXsoQPJS9TlX1zD2uQLrJUB-1VpAJXRHfs,3475
97
- langfun/core/llms/vertexai.py,sha256=MuwLPTJ6-9x2uRDCSM1_biPK6M76FFlL1ezf5OmobDA,5504
98
- langfun/core/llms/vertexai_test.py,sha256=iXjmQs7TNiwcueoaRGpdp4KnASkDJaTP__Z9QroN8zQ,1787
97
+ langfun/core/llms/vertexai.py,sha256=JV9iHsCM3Ee-4nE1ENNkTXIYGxjCHxrEeir175YpCM8,7869
98
+ langfun/core/llms/vertexai_test.py,sha256=6eLQOyeL5iGZOIWb39sFcf1TgYD_6TBGYdMO4UIvhf4,3333
99
99
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
100
100
  langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
101
101
  langfun/core/llms/cache/in_memory.py,sha256=i58oiQL28RDsq37dwqgVpC2mBETJjIEFS20yHiV5MKU,5185
@@ -146,8 +146,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
146
146
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
147
147
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
148
148
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
149
- langfun-0.1.2.dev202501150804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
150
- langfun-0.1.2.dev202501150804.dist-info/METADATA,sha256=HAqG3RU6S4kwv8Knx7vDVp01x8Kuc7It1QnYg0oeol8,8172
151
- langfun-0.1.2.dev202501150804.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
152
- langfun-0.1.2.dev202501150804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
153
- langfun-0.1.2.dev202501150804.dist-info/RECORD,,
149
+ langfun-0.1.2.dev202501170804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
150
+ langfun-0.1.2.dev202501170804.dist-info/METADATA,sha256=X3MDNl6D6StuwltvUclYhE20uKNQ2x8lY3CkPggJyI4,8172
151
+ langfun-0.1.2.dev202501170804.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
152
+ langfun-0.1.2.dev202501170804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
153
+ langfun-0.1.2.dev202501170804.dist-info/RECORD,,