langfun 0.1.2.dev202412020805__py3-none-any.whl → 0.1.2.dev202412030000__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
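For orientation, the short sketch below shows how the REST-based OpenAI class that appears in the diff is constructed and called. It is only a usage sketch: the constructor arguments (model, api_key, organization, project) and call patterns are taken from the updated tests further down, the API key is a placeholder, and the environment-variable fallbacks (OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT) are those read in _initialize().

import langfun.core as lf
from langfun.core.llms import openai

# Construct the model. `model` must be a key of SUPPORTED_MODELS_AND_SETTINGS;
# `api_key`, `organization` and `project` fall back to the OPENAI_API_KEY,
# OPENAI_ORGANIZATION and OPENAI_PROJECT environment variables when omitted.
lm = openai.OpenAI(
    model='gpt-4',
    api_key='sk-placeholder',   # placeholder; use a real key or the env var
    organization='my_org',      # optional
    project='my_project',       # optional
)

# Single call; sampling options are translated into the chat/completions
# request body by _request_args() (e.g. max_tokens -> max_completion_tokens).
print(lm('Say hello.', sampling_options=lf.LMSamplingOptions(temperature=0.0)))

# Batched sampling with token logprobs, as exercised in the tests.
results = lm.sample(['hello'], logprobs=True)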
@@ -13,34 +13,14 @@
13
13
  # limitations under the License.
14
14
  """Language models from OpenAI."""
15
15
 
16
- import collections
17
- import functools
18
16
  import os
19
17
  from typing import Annotated, Any
20
18
 
21
19
  import langfun.core as lf
22
20
  from langfun.core import modalities as lf_modalities
21
+ from langfun.core.llms import rest
23
22
  import pyglove as pg
24
23
 
25
- try:
26
- import openai # pylint: disable=g-import-not-at-top
27
-
28
- if hasattr(openai, 'error'):
29
- # For lower versions.
30
- ServiceUnavailableError = openai.error.ServiceUnavailableError
31
- RateLimitError = openai.error.RateLimitError
32
- APITimeoutError = (
33
- openai.error.APIError,
34
- '.*The server had an error processing your request'
35
- )
36
- else:
37
- # For higher versions.
38
- ServiceUnavailableError = getattr(openai, 'InternalServerError')
39
- RateLimitError = getattr(openai, 'RateLimitError')
40
- APITimeoutError = getattr(openai, 'APITimeoutError')
41
- except ImportError:
42
- openai = None
43
-
44
24
 
45
25
  # From https://platform.openai.com/settings/organization/limits
46
26
  _DEFAULT_TPM = 250000
@@ -289,7 +269,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
289
269
  rpm=_DEFAULT_RPM,
290
270
  tpm=_DEFAULT_TPM
291
271
  ),
292
- # GPT-3 instruction-tuned models
272
+ # GPT-3 instruction-tuned models (Deprecated)
293
273
  'text-curie-001': pg.Dict(
294
274
  in_service=False,
295
275
  rpm=_DEFAULT_RPM,
@@ -325,9 +305,9 @@ SUPPORTED_MODELS_AND_SETTINGS = {
325
305
  rpm=_DEFAULT_RPM,
326
306
  tpm=_DEFAULT_TPM
327
307
  ),
328
- # GPT-3 base models
308
+ # GPT-3 base models that are still in service.
329
309
  'babbage-002': pg.Dict(
330
- in_service=False,
310
+ in_service=True,
331
311
  rpm=_DEFAULT_RPM,
332
312
  tpm=_DEFAULT_TPM
333
313
  ),
@@ -340,7 +320,7 @@ SUPPORTED_MODELS_AND_SETTINGS = {
340
320
 
341
321
 
342
322
  @lf.use_init_args(['model'])
343
- class OpenAI(lf.LanguageModel):
323
+ class OpenAI(rest.REST):
344
324
  """OpenAI model."""
345
325
 
346
326
  model: pg.typing.Annotated[
@@ -348,7 +328,9 @@ class OpenAI(lf.LanguageModel):
348
328
  pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
349
329
  ),
350
330
  'The name of the model to use.',
351
- ] = 'gpt-3.5-turbo'
331
+ ]
332
+
333
+ api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
352
334
 
353
335
  multimodal: Annotated[
354
336
  bool,
@@ -372,27 +354,45 @@ class OpenAI(lf.LanguageModel):
372
354
  ),
373
355
  ] = None
374
356
 
357
+ project: Annotated[
358
+ str | None,
359
+ (
360
+ 'Project. If None, the key will be read from environment '
361
+ "variable 'OPENAI_PROJECT'. Based on the value, usages from "
362
+ "these API requests will count against the project's quota. "
363
+ ),
364
+ ] = None
365
+
375
366
  def _on_bound(self):
376
367
  super()._on_bound()
377
- self.__dict__.pop('_api_initialized', None)
378
- if openai is None:
379
- raise RuntimeError(
380
- 'Please install "langfun[llm-openai]" to use OpenAI models.'
381
- )
368
+ self._api_key = None
369
+ self._organization = None
370
+ self._project = None
382
371
 
383
- @functools.cached_property
384
- def _api_initialized(self):
372
+ def _initialize(self):
385
373
  api_key = self.api_key or os.environ.get('OPENAI_API_KEY', None)
386
374
  if not api_key:
387
375
  raise ValueError(
388
376
  'Please specify `api_key` during `__init__` or set environment '
389
377
  'variable `OPENAI_API_KEY` with your OpenAI API key.'
390
378
  )
391
- openai.api_key = api_key
392
- org = self.organization or os.environ.get('OPENAI_ORGANIZATION', None)
393
- if org:
394
- openai.organization = org
395
- return True
379
+ self._api_key = api_key
380
+ self._organization = self.organization or os.environ.get(
381
+ 'OPENAI_ORGANIZATION', None
382
+ )
383
+ self._project = self.project or os.environ.get('OPENAI_PROJECT', None)
384
+
385
+ @property
386
+ def headers(self) -> dict[str, Any]:
387
+ headers = {
388
+ 'Content-Type': 'application/json',
389
+ 'Authorization': f'Bearer {self._api_key}',
390
+ }
391
+ if self._organization:
392
+ headers['OpenAI-Organization'] = self._organization
393
+ if self._project:
394
+ headers['OpenAI-Project'] = self._project
395
+ return headers
396
396
 
397
397
  @property
398
398
  def model_id(self) -> str:
@@ -428,23 +428,16 @@ class OpenAI(lf.LanguageModel):
428
428
 
429
429
  @classmethod
430
430
  def dir(cls):
431
- assert openai is not None
432
- return openai.Model.list()
431
+ return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
433
432
 
434
- @property
435
- def is_chat_model(self):
436
- """Returns True if the model is a chat model."""
437
- return self.model.startswith(('o1', 'gpt-4', 'gpt-3.5-turbo'))
438
-
439
- def _get_request_args(
433
+ def _request_args(
440
434
  self, options: lf.LMSamplingOptions) -> dict[str, Any]:
441
435
  # Reference:
442
436
  # https://platform.openai.com/docs/api-reference/completions/create
443
437
  # NOTE(daiyip): options.top_k is not applicable.
444
438
  args = dict(
439
+ model=self.model,
445
440
  n=options.n,
446
- stream=False,
447
- timeout=self.timeout,
448
441
  top_logprobs=options.top_logprobs,
449
442
  )
450
443
  if options.logprobs:
@@ -453,13 +446,10 @@ class OpenAI(lf.LanguageModel):
453
446
  raise RuntimeError(f'`logprobs` is not supported on {self.model!r}.')
454
447
  args['logprobs'] = options.logprobs
455
448
 
456
- # Completion and ChatCompletion uses different parameter name for model.
457
- args['model' if self.is_chat_model else 'engine'] = self.model
458
-
459
449
  if options.temperature is not None:
460
450
  args['temperature'] = options.temperature
461
451
  if options.max_tokens is not None:
462
- args['max_tokens'] = options.max_tokens
452
+ args['max_completion_tokens'] = options.max_tokens
463
453
  if options.top_p is not None:
464
454
  args['top_p'] = options.top_p
465
455
  if options.stop:
@@ -468,168 +458,113 @@ class OpenAI(lf.LanguageModel):
468
458
  args['seed'] = options.random_seed
469
459
  return args
470
460
 
471
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
472
- assert self._api_initialized
473
- if self.is_chat_model:
474
- return self._chat_complete_batch(prompts)
475
- else:
476
- return self._complete_batch(prompts)
477
-
478
- def _complete_batch(
479
- self, prompts: list[lf.Message]
480
- ) -> list[lf.LMSamplingResult]:
481
-
482
- def _open_ai_completion(prompts):
483
- assert openai is not None
484
- response = openai.Completion.create(
485
- prompt=[p.text for p in prompts],
486
- **self._get_request_args(self.sampling_options),
487
- )
488
- # Parse response.
489
- samples_by_index = collections.defaultdict(list)
490
- for choice in response.choices:
491
- samples_by_index[choice.index].append(
492
- lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
493
- )
494
-
495
- n = len(samples_by_index)
496
- estimated_cost = self.estimate_cost(
497
- num_input_tokens=response.usage.prompt_tokens,
498
- num_output_tokens=response.usage.completion_tokens,
499
- )
500
- usage = lf.LMSamplingUsage(
501
- prompt_tokens=response.usage.prompt_tokens // n,
502
- completion_tokens=response.usage.completion_tokens // n,
503
- total_tokens=response.usage.total_tokens // n,
504
- estimated_cost=(
505
- None if estimated_cost is None else (estimated_cost // n)
506
- )
507
- )
508
- return [
509
- lf.LMSamplingResult(samples_by_index[index], usage=usage)
510
- for index in sorted(samples_by_index.keys())
511
- ]
512
-
513
- return self._parallel_execute_with_currency_control(
514
- _open_ai_completion,
515
- [prompts],
516
- retry_on_errors=(
517
- ServiceUnavailableError,
518
- RateLimitError,
519
- APITimeoutError,
520
- ),
521
- )[0]
522
-
523
- def _chat_complete_batch(
524
- self, prompts: list[lf.Message]
525
- ) -> list[lf.LMSamplingResult]:
526
- def _content_from_message(message: lf.Message):
527
- if self.multimodal:
528
- content = []
529
- for chunk in message.chunk():
530
- if isinstance(chunk, str):
531
- item = dict(type='text', text=chunk)
532
- elif isinstance(chunk, lf_modalities.Image):
533
- if chunk.uri and chunk.uri.lower().startswith(
534
- ('http:', 'https:', 'ftp:')
535
- ):
536
- uri = chunk.uri
537
- else:
538
- uri = chunk.content_uri
539
- item = dict(type='image_url', image_url=dict(url=uri))
540
- else:
541
- raise ValueError(f'Unsupported modality object: {chunk!r}.')
542
- content.append(item)
461
+ def _content_from_message(self, message: lf.Message):
462
+ """Returns an OpenAI content object from a Langfun message."""
463
+ def _uri_from(chunk: lf.Modality) -> str:
464
+ if chunk.uri and chunk.uri.lower().startswith(
465
+ ('http:', 'https:', 'ftp:')
466
+ ):
467
+ return chunk.uri
468
+ return chunk.content_uri
469
+
470
+ content = []
471
+ for chunk in message.chunk():
472
+ if isinstance(chunk, str):
473
+ item = dict(type='text', text=chunk)
474
+ elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
475
+ item = dict(type='image_url', image_url=dict(url=_uri_from(chunk)))
543
476
  else:
544
- content = message.text
545
- return content
546
-
547
- def _open_ai_chat_completion(prompt: lf.Message):
548
- request_args = self._get_request_args(self.sampling_options)
549
- # Users could use `metadata_json_schema` to pass additional
550
- # request arguments.
551
- json_schema = prompt.metadata.get('json_schema')
552
- if json_schema is not None:
553
- if not isinstance(json_schema, dict):
554
- raise ValueError(
555
- f'`json_schema` must be a dict, got {json_schema!r}.'
556
- )
557
- if 'title' not in json_schema:
558
- raise ValueError(
559
- f'The root of `json_schema` must have a `title` field, '
560
- f'got {json_schema!r}.'
561
- )
562
- request_args.update(
563
- response_format=dict(
564
- type='json_schema',
565
- json_schema=dict(
566
- schema=json_schema,
567
- name=json_schema['title'],
568
- strict=True,
569
- )
570
- )
571
- )
572
- prompt.metadata.formatted_text = (
573
- prompt.text
574
- + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
575
- + pg.to_json_str(request_args['response_format'], json_indent=2)
576
- )
477
+ raise ValueError(f'Unsupported modality: {chunk!r}.')
478
+ content.append(item)
479
+ return content
577
480
 
578
- # Prepare messages.
579
- messages = []
580
- # Users could use `metadata_system_message` to pass system message.
581
- system_message = prompt.metadata.get('system_message')
582
- if system_message:
583
- system_message = lf.SystemMessage.from_value(system_message)
584
- messages.append(
585
- dict(role='system', content=_content_from_message(system_message))
481
+ def request(
482
+ self,
483
+ prompt: lf.Message,
484
+ sampling_options: lf.LMSamplingOptions
485
+ ) -> dict[str, Any]:
486
+ """Returns the JSON input for a message."""
487
+ request_args = self._request_args(sampling_options)
488
+
489
+ # Users could use `metadata_json_schema` to pass additional
490
+ # request arguments.
491
+ json_schema = prompt.metadata.get('json_schema')
492
+ if json_schema is not None:
493
+ if not isinstance(json_schema, dict):
494
+ raise ValueError(
495
+ f'`json_schema` must be a dict, got {json_schema!r}.'
586
496
  )
587
- messages.append(dict(role='user', content=_content_from_message(prompt)))
588
-
589
- assert openai is not None
590
- response = openai.ChatCompletion.create(messages=messages, **request_args)
591
-
592
- samples = []
593
- for choice in response.choices:
594
- logprobs = None
595
- choice_logprobs = getattr(choice, 'logprobs', None)
596
- if choice_logprobs:
597
- logprobs = [
598
- (
599
- t.token,
600
- t.logprob,
601
- [(tt.token, tt.logprob) for tt in t.top_logprobs],
602
- )
603
- for t in choice_logprobs.content
604
- ]
605
- samples.append(
606
- lf.LMSample(
607
- choice.message.content,
608
- score=0.0,
609
- logprobs=logprobs,
610
- )
497
+ if 'title' not in json_schema:
498
+ raise ValueError(
499
+ f'The root of `json_schema` must have a `title` field, '
500
+ f'got {json_schema!r}.'
611
501
  )
612
-
613
- return lf.LMSamplingResult(
614
- samples=samples,
615
- usage=lf.LMSamplingUsage(
616
- prompt_tokens=response.usage.prompt_tokens,
617
- completion_tokens=response.usage.completion_tokens,
618
- total_tokens=response.usage.total_tokens,
619
- estimated_cost=self.estimate_cost(
620
- num_input_tokens=response.usage.prompt_tokens,
621
- num_output_tokens=response.usage.completion_tokens,
502
+ request_args.update(
503
+ response_format=dict(
504
+ type='json_schema',
505
+ json_schema=dict(
506
+ schema=json_schema,
507
+ name=json_schema['title'],
508
+ strict=True,
622
509
  )
623
- ),
510
+ )
511
+ )
512
+ prompt.metadata.formatted_text = (
513
+ prompt.text
514
+ + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
515
+ + pg.to_json_str(request_args['response_format'], json_indent=2)
516
+ )
517
+
518
+ # Prepare messages.
519
+ messages = []
520
+ # Users could use `metadata_system_message` to pass the system message.
521
+ system_message = prompt.metadata.get('system_message')
522
+ if system_message:
523
+ system_message = lf.SystemMessage.from_value(system_message)
524
+ messages.append(
525
+ dict(role='system',
526
+ content=self._content_from_message(system_message))
624
527
  )
528
+ messages.append(
529
+ dict(role='user', content=self._content_from_message(prompt))
530
+ )
531
+ request = dict()
532
+ request.update(request_args)
533
+ request['messages'] = messages
534
+ return request
625
535
 
626
- return self._parallel_execute_with_currency_control(
627
- _open_ai_chat_completion,
628
- prompts,
629
- retry_on_errors=(
630
- ServiceUnavailableError,
631
- RateLimitError,
632
- APITimeoutError
536
+ def _parse_choice(self, choice: dict[str, Any]) -> lf.LMSample:
537
+ # Reference:
538
+ # https://platform.openai.com/docs/api-reference/chat/object
539
+ logprobs = None
540
+ choice_logprobs = choice.get('logprobs')
541
+ if choice_logprobs:
542
+ logprobs = [
543
+ (
544
+ t['token'],
545
+ t['logprob'],
546
+ [(tt['token'], tt['logprob']) for tt in t['top_logprobs']],
547
+ )
548
+ for t in choice_logprobs['content']
549
+ ]
550
+ return lf.LMSample(
551
+ choice['message']['content'],
552
+ score=0.0,
553
+ logprobs=logprobs,
554
+ )
555
+
556
+ def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
557
+ usage = json['usage']
558
+ return lf.LMSamplingResult(
559
+ samples=[self._parse_choice(choice) for choice in json['choices']],
560
+ usage=lf.LMSamplingUsage(
561
+ prompt_tokens=usage['prompt_tokens'],
562
+ completion_tokens=usage['completion_tokens'],
563
+ total_tokens=usage['total_tokens'],
564
+ estimated_cost=self.estimate_cost(
565
+ num_input_tokens=usage['prompt_tokens'],
566
+ num_output_tokens=usage['completion_tokens'],
567
+ )
633
568
  ),
634
569
  )
635
570
 
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Tests for OpenAI models."""
15
15
 
16
+ from typing import Any
16
17
  import unittest
17
18
  from unittest import mock
18
19
 
@@ -20,86 +21,106 @@ import langfun.core as lf
20
21
  from langfun.core import modalities as lf_modalities
21
22
  from langfun.core.llms import openai
22
23
  import pyglove as pg
24
+ import requests
23
25
 
24
26
 
25
- def mock_completion_query(prompt, *, n=1, **kwargs):
26
- del kwargs
27
- choices = []
28
- for i, _ in enumerate(prompt):
29
- for k in range(n):
30
- choices.append(pg.Dict(
31
- index=i,
32
- text=f'Sample {k} for prompt {i}.',
33
- logprobs=k / 10,
34
- ))
35
- return pg.Dict(
36
- choices=choices,
37
- usage=lf.LMSamplingUsage(
38
- prompt_tokens=100,
39
- completion_tokens=100,
40
- total_tokens=200,
41
- ),
42
- )
43
-
44
-
45
- def mock_chat_completion_query(messages, *, n=1, **kwargs):
27
+ def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
28
+ del url, kwargs
29
+ messages = json['messages']
46
30
  if len(messages) > 1:
47
31
  system_message = f' system={messages[0]["content"]}'
48
32
  else:
49
33
  system_message = ''
50
34
 
51
- if 'response_format' in kwargs:
52
- response_format = f' format={kwargs["response_format"]["type"]}'
35
+ if 'response_format' in json:
36
+ response_format = f' format={json["response_format"]["type"]}'
53
37
  else:
54
38
  response_format = ''
55
39
 
56
40
  choices = []
57
- for k in range(n):
58
- choices.append(pg.Dict(
59
- message=pg.Dict(
41
+ for k in range(json['n']):
42
+ if json.get('logprobs'):
43
+ logprobs = dict(
44
+ content=[
45
+ dict(
46
+ token='chosen_token',
47
+ logprob=0.5,
48
+ top_logprobs=[
49
+ dict(
50
+ token=f'alternative_token_{i + 1}',
51
+ logprob=0.1
52
+ ) for i in range(3)
53
+ ]
54
+ )
55
+ ]
56
+ )
57
+ else:
58
+ logprobs = None
59
+
60
+ choices.append(dict(
61
+ message=dict(
60
62
  content=(
61
63
  f'Sample {k} for message.{system_message}{response_format}'
62
64
  )
63
65
  ),
64
- logprobs=None,
66
+ logprobs=logprobs,
65
67
  ))
66
- return pg.Dict(
67
- choices=choices,
68
- usage=lf.LMSamplingUsage(
69
- prompt_tokens=100,
70
- completion_tokens=100,
71
- total_tokens=200,
72
- ),
73
- )
68
+ response = requests.Response()
69
+ response.status_code = 200
70
+ response._content = pg.to_json_str(
71
+ dict(
72
+ choices=choices,
73
+ usage=lf.LMSamplingUsage(
74
+ prompt_tokens=100,
75
+ completion_tokens=100,
76
+ total_tokens=200,
77
+ ),
78
+ )
79
+ ).encode()
80
+ return response
74
81
 
75
82
 
76
- def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
77
- del kwargs
83
+ def mock_chat_completion_request_vision(
84
+ url: str, json: dict[str, Any], **kwargs
85
+ ):
86
+ del url, kwargs
78
87
  choices = []
79
88
  urls = [
80
89
  c['image_url']['url']
81
- for c in messages[0]['content'] if c['type'] == 'image_url'
90
+ for c in json['messages'][0]['content'] if c['type'] == 'image_url'
82
91
  ]
83
- for k in range(n):
92
+ for k in range(json['n']):
84
93
  choices.append(pg.Dict(
85
94
  message=pg.Dict(
86
95
  content=f'Sample {k} for message: {"".join(urls)}'
87
96
  ),
88
97
  logprobs=None,
89
98
  ))
90
- return pg.Dict(
91
- choices=choices,
92
- usage=lf.LMSamplingUsage(
93
- prompt_tokens=100,
94
- completion_tokens=100,
95
- total_tokens=200,
96
- ),
97
- )
99
+ response = requests.Response()
100
+ response.status_code = 200
101
+ response._content = pg.to_json_str(
102
+ dict(
103
+ choices=choices,
104
+ usage=lf.LMSamplingUsage(
105
+ prompt_tokens=100,
106
+ completion_tokens=100,
107
+ total_tokens=200,
108
+ ),
109
+ )
110
+ ).encode()
111
+ return response
98
112
 
99
113
 
100
114
  class OpenAITest(unittest.TestCase):
101
115
  """Tests for OpenAI language model."""
102
116
 
117
+ def test_dir(self):
118
+ self.assertIn('gpt-4-turbo', openai.OpenAI.dir())
119
+
120
+ def test_key(self):
121
+ with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
122
+ openai.Gpt4()('hi')
123
+
103
124
  def test_model_id(self):
104
125
  self.assertEqual(
105
126
  openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
@@ -112,29 +133,9 @@ class OpenAITest(unittest.TestCase):
112
133
  def test_max_concurrency(self):
113
134
  self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
114
135
 
115
- def test_get_request_args(self):
116
- self.assertEqual(
117
- openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
118
- lf.LMSamplingOptions(
119
- temperature=2.0,
120
- logprobs=True,
121
- n=2,
122
- max_tokens=4096,
123
- top_p=1.0)),
124
- dict(
125
- engine='text-davinci-003',
126
- logprobs=True,
127
- top_logprobs=None,
128
- n=2,
129
- temperature=2.0,
130
- max_tokens=4096,
131
- stream=False,
132
- timeout=90.0,
133
- top_p=1.0,
134
- )
135
- )
136
+ def test_request_args(self):
136
137
  self.assertEqual(
137
- openai.Gpt4(api_key='test_key')._get_request_args(
138
+ openai.Gpt4(api_key='test_key')._request_args(
138
139
  lf.LMSamplingOptions(
139
140
  temperature=1.0, stop=['\n'], n=1, random_seed=123
140
141
  )
@@ -144,40 +145,93 @@ class OpenAITest(unittest.TestCase):
144
145
  top_logprobs=None,
145
146
  n=1,
146
147
  temperature=1.0,
147
- stream=False,
148
- timeout=120.0,
149
148
  stop=['\n'],
150
149
  seed=123,
151
150
  ),
152
151
  )
153
152
  with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'):
154
- openai.GptO1Preview(api_key='test_key')._get_request_args(
153
+ openai.GptO1Preview(api_key='test_key')._request_args(
155
154
  lf.LMSamplingOptions(
156
155
  temperature=1.0, logprobs=True
157
156
  )
158
157
  )
159
158
 
160
- def test_call_completion(self):
161
- with mock.patch('openai.Completion.create') as mock_completion:
162
- mock_completion.side_effect = mock_completion_query
163
- lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
159
+ def test_call_chat_completion(self):
160
+ with mock.patch('requests.Session.post') as mock_request:
161
+ mock_request.side_effect = mock_chat_completion_request
162
+ lm = openai.OpenAI(
163
+ model='gpt-4',
164
+ api_key='test_key',
165
+ organization='my_org',
166
+ project='my_project'
167
+ )
164
168
  self.assertEqual(
165
169
  lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
166
- 'Sample 0 for prompt 0.',
170
+ 'Sample 0 for message.',
167
171
  )
168
172
 
169
- def test_call_chat_completion(self):
170
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
171
- mock_chat_completion.side_effect = mock_chat_completion_query
172
- lm = openai.OpenAI(api_key='test_key', model='gpt-4')
173
+ def test_call_chat_completion_with_logprobs(self):
174
+ with mock.patch('requests.Session.post') as mock_request:
175
+ mock_request.side_effect = mock_chat_completion_request
176
+ lm = openai.OpenAI(
177
+ model='gpt-4',
178
+ api_key='test_key',
179
+ organization='my_org',
180
+ project='my_project'
181
+ )
182
+ results = lm.sample(['hello'], logprobs=True)
183
+ self.assertEqual(len(results), 1)
173
184
  self.assertEqual(
174
- lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
175
- 'Sample 0 for message.',
185
+ results[0],
186
+ lf.LMSamplingResult(
187
+ [
188
+ lf.LMSample(
189
+ response=lf.AIMessage(
190
+ text='Sample 0 for message.',
191
+ metadata={
192
+ 'score': 0.0,
193
+ 'logprobs': [(
194
+ 'chosen_token',
195
+ 0.5,
196
+ [
197
+ ('alternative_token_1', 0.1),
198
+ ('alternative_token_2', 0.1),
199
+ ('alternative_token_3', 0.1),
200
+ ],
201
+ )],
202
+ 'is_cached': False,
203
+ 'usage': lf.LMSamplingUsage(
204
+ prompt_tokens=100,
205
+ completion_tokens=100,
206
+ total_tokens=200,
207
+ estimated_cost=0.009,
208
+ ),
209
+ },
210
+ tags=['lm-response'],
211
+ ),
212
+ logprobs=[(
213
+ 'chosen_token',
214
+ 0.5,
215
+ [
216
+ ('alternative_token_1', 0.1),
217
+ ('alternative_token_2', 0.1),
218
+ ('alternative_token_3', 0.1),
219
+ ],
220
+ )],
221
+ )
222
+ ],
223
+ usage=lf.LMSamplingUsage(
224
+ prompt_tokens=100,
225
+ completion_tokens=100,
226
+ total_tokens=200,
227
+ estimated_cost=0.009,
228
+ ),
229
+ ),
176
230
  )
177
231
 
178
232
  def test_call_chat_completion_vision(self):
179
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
180
- mock_chat_completion.side_effect = mock_chat_completion_query_vision
233
+ with mock.patch('requests.Session.post') as mock_request:
234
+ mock_request.side_effect = mock_chat_completion_request_vision
181
235
  lm_1 = openai.Gpt4Turbo(api_key='test_key')
182
236
  lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
183
237
  for lm in (lm_1, lm_2):
@@ -191,136 +245,18 @@ class OpenAITest(unittest.TestCase):
191
245
  ),
192
246
  'Sample 0 for message: https://fake/image',
193
247
  )
194
-
195
- def test_sample_completion(self):
196
- with mock.patch('openai.Completion.create') as mock_completion:
197
- mock_completion.side_effect = mock_completion_query
198
- lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
199
- results = lm.sample(
200
- ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
248
+ lm_3 = openai.Gpt35Turbo(api_key='test_key')
249
+ with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
250
+ lm_3(
251
+ lf.UserMessage(
252
+ 'hello <<[[image]]>>',
253
+ image=lf_modalities.Image.from_uri('https://fake/image')
254
+ ),
201
255
  )
202
256
 
203
- self.assertEqual(len(results), 2)
204
- self.assertEqual(
205
- results[0],
206
- lf.LMSamplingResult(
207
- [
208
- lf.LMSample(
209
- lf.AIMessage(
210
- 'Sample 0 for prompt 0.',
211
- score=0.0,
212
- logprobs=None,
213
- is_cached=False,
214
- usage=lf.LMSamplingUsage(
215
- prompt_tokens=16,
216
- completion_tokens=16,
217
- total_tokens=33
218
- ),
219
- tags=[lf.Message.TAG_LM_RESPONSE],
220
- ),
221
- score=0.0,
222
- logprobs=None,
223
- ),
224
- lf.LMSample(
225
- lf.AIMessage(
226
- 'Sample 1 for prompt 0.',
227
- score=0.1,
228
- logprobs=None,
229
- is_cached=False,
230
- usage=lf.LMSamplingUsage(
231
- prompt_tokens=16,
232
- completion_tokens=16,
233
- total_tokens=33
234
- ),
235
- tags=[lf.Message.TAG_LM_RESPONSE],
236
- ),
237
- score=0.1,
238
- logprobs=None,
239
- ),
240
- lf.LMSample(
241
- lf.AIMessage(
242
- 'Sample 2 for prompt 0.',
243
- score=0.2,
244
- logprobs=None,
245
- is_cached=False,
246
- usage=lf.LMSamplingUsage(
247
- prompt_tokens=16,
248
- completion_tokens=16,
249
- total_tokens=33
250
- ),
251
- tags=[lf.Message.TAG_LM_RESPONSE],
252
- ),
253
- score=0.2,
254
- logprobs=None,
255
- ),
256
- ],
257
- usage=lf.LMSamplingUsage(
258
- prompt_tokens=50, completion_tokens=50, total_tokens=100
259
- ),
260
- ),
261
- )
262
- self.assertEqual(
263
- results[1],
264
- lf.LMSamplingResult(
265
- [
266
- lf.LMSample(
267
- lf.AIMessage(
268
- 'Sample 0 for prompt 1.',
269
- score=0.0,
270
- logprobs=None,
271
- is_cached=False,
272
- usage=lf.LMSamplingUsage(
273
- prompt_tokens=16,
274
- completion_tokens=16,
275
- total_tokens=33
276
- ),
277
- tags=[lf.Message.TAG_LM_RESPONSE],
278
- ),
279
- score=0.0,
280
- logprobs=None,
281
- ),
282
- lf.LMSample(
283
- lf.AIMessage(
284
- 'Sample 1 for prompt 1.',
285
- score=0.1,
286
- logprobs=None,
287
- is_cached=False,
288
- usage=lf.LMSamplingUsage(
289
- prompt_tokens=16,
290
- completion_tokens=16,
291
- total_tokens=33
292
- ),
293
- tags=[lf.Message.TAG_LM_RESPONSE],
294
- ),
295
- score=0.1,
296
- logprobs=None,
297
- ),
298
- lf.LMSample(
299
- lf.AIMessage(
300
- 'Sample 2 for prompt 1.',
301
- score=0.2,
302
- logprobs=None,
303
- is_cached=False,
304
- usage=lf.LMSamplingUsage(
305
- prompt_tokens=16,
306
- completion_tokens=16,
307
- total_tokens=33
308
- ),
309
- tags=[lf.Message.TAG_LM_RESPONSE],
310
- ),
311
- score=0.2,
312
- logprobs=None,
313
- ),
314
- ],
315
- usage=lf.LMSamplingUsage(
316
- prompt_tokens=50, completion_tokens=50, total_tokens=100
317
- ),
318
- ),
319
- )
320
-
321
257
  def test_sample_chat_completion(self):
322
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
323
- mock_chat_completion.side_effect = mock_chat_completion_query
258
+ with mock.patch('requests.Session.post') as mock_request:
259
+ mock_request.side_effect = mock_chat_completion_request
324
260
  openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({
325
261
  'cost_per_1k_input_tokens': 1.0,
326
262
  'cost_per_1k_output_tokens': 1.0,
@@ -458,8 +394,8 @@ class OpenAITest(unittest.TestCase):
458
394
  )
459
395
 
460
396
  def test_sample_with_contextual_options(self):
461
- with mock.patch('openai.Completion.create') as mock_completion:
462
- mock_completion.side_effect = mock_completion_query
397
+ with mock.patch('requests.Session.post') as mock_request:
398
+ mock_request.side_effect = mock_chat_completion_request
463
399
  lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
464
400
  with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
465
401
  results = lm.sample(['hello'])
@@ -471,7 +407,7 @@ class OpenAITest(unittest.TestCase):
471
407
  [
472
408
  lf.LMSample(
473
409
  lf.AIMessage(
474
- 'Sample 0 for prompt 0.',
410
+ 'Sample 0 for message.',
475
411
  score=0.0,
476
412
  logprobs=None,
477
413
  is_cached=False,
@@ -487,8 +423,8 @@ class OpenAITest(unittest.TestCase):
487
423
  ),
488
424
  lf.LMSample(
489
425
  lf.AIMessage(
490
- 'Sample 1 for prompt 0.',
491
- score=0.1,
426
+ 'Sample 1 for message.',
427
+ score=0.0,
492
428
  logprobs=None,
493
429
  is_cached=False,
494
430
  usage=lf.LMSamplingUsage(
@@ -498,19 +434,19 @@ class OpenAITest(unittest.TestCase):
498
434
  ),
499
435
  tags=[lf.Message.TAG_LM_RESPONSE],
500
436
  ),
501
- score=0.1,
437
+ score=0.0,
502
438
  logprobs=None,
503
439
  ),
504
440
  ],
505
441
  usage=lf.LMSamplingUsage(
506
442
  prompt_tokens=100, completion_tokens=100, total_tokens=200
507
443
  ),
508
- ),
444
+ )
509
445
  )
510
446
 
511
447
  def test_call_with_system_message(self):
512
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
513
- mock_chat_completion.side_effect = mock_chat_completion_query
448
+ with mock.patch('requests.Session.post') as mock_request:
449
+ mock_request.side_effect = mock_chat_completion_request
514
450
  lm = openai.OpenAI(api_key='test_key', model='gpt-4')
515
451
  self.assertEqual(
516
452
  lm(
@@ -520,12 +456,12 @@ class OpenAITest(unittest.TestCase):
520
456
  ),
521
457
  sampling_options=lf.LMSamplingOptions(n=2)
522
458
  ),
523
- 'Sample 0 for message. system=hi',
459
+ '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
524
460
  )
525
461
 
526
462
  def test_call_with_json_schema(self):
527
- with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
528
- mock_chat_completion.side_effect = mock_chat_completion_query
463
+ with mock.patch('requests.Session.post') as mock_request:
464
+ mock_request.side_effect = mock_chat_completion_request
529
465
  lm = openai.OpenAI(api_key='test_key', model='gpt-4')
530
466
  self.assertEqual(
531
467
  lm(
@@ -426,6 +426,7 @@ class VertexRestfulAITest(unittest.TestCase):
426
426
  model,
427
427
  )
428
428
 
429
+ @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True)
429
430
  def test_project_and_location_check(self):
430
431
  with self.assertRaisesRegex(ValueError, 'Please specify `project`'):
431
432
  _ = vertexai.VertexAIGeminiPro1()._api_initialized
@@ -496,6 +497,7 @@ class VertexRestfulAITest(unittest.TestCase):
496
497
  lf.LMSamplingOptions(),
497
498
  )
498
499
 
500
+ @mock.patch.object(vertexai.VertexAIRest, 'credentials', new=True)
499
501
  def test_call_model(self):
500
502
  with mock.patch('requests.Session.post') as mock_generate:
501
503
  mock_generate.side_effect = mock_requests_post
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202412020805
3
+ Version: 0.1.2.dev202412030000
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -32,7 +32,6 @@ Requires-Dist: termcolor==1.1.0; extra == "all"
32
32
  Requires-Dist: tqdm>=4.64.1; extra == "all"
33
33
  Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "all"
34
34
  Requires-Dist: google-generativeai>=0.3.2; extra == "all"
35
- Requires-Dist: openai>=0.27.2; extra == "all"
36
35
  Requires-Dist: python-magic>=0.4.27; extra == "all"
37
36
  Requires-Dist: python-docx>=0.8.11; extra == "all"
38
37
  Requires-Dist: pillow>=10.0.0; extra == "all"
@@ -44,7 +43,6 @@ Requires-Dist: tqdm>=4.64.1; extra == "ui"
44
43
  Provides-Extra: llm
45
44
  Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm"
46
45
  Requires-Dist: google-generativeai>=0.3.2; extra == "llm"
47
- Requires-Dist: openai>=0.27.2; extra == "llm"
48
46
  Provides-Extra: llm-google
49
47
  Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm-google"
50
48
  Requires-Dist: google-generativeai>=0.3.2; extra == "llm-google"
@@ -52,8 +50,6 @@ Provides-Extra: llm-google-vertex
52
50
  Requires-Dist: google-cloud-aiplatform>=1.5.0; extra == "llm-google-vertex"
53
51
  Provides-Extra: llm-google-genai
54
52
  Requires-Dist: google-generativeai>=0.3.2; extra == "llm-google-genai"
55
- Provides-Extra: llm-openai
56
- Requires-Dist: openai>=0.27.2; extra == "llm-openai"
57
53
  Provides-Extra: mime
58
54
  Requires-Dist: python-magic>=0.4.27; extra == "mime"
59
55
  Requires-Dist: python-docx>=0.8.11; extra == "mime"
@@ -214,7 +210,6 @@ If you want to customize your installation, you can select specific features usi
214
210
  | llm-google | All supported Google-powered LLMs. |
215
211
  | llm-google-vertexai | LLMs powered by Google Cloud VertexAI |
216
212
  | llm-google-genai | LLMs powered by Google Generative AI API |
217
- | llm-openai | LLMs powered by OpenAI |
218
213
  | mime | All MIME supports. |
219
214
  | mime-auto | Automatic MIME type detection. |
220
215
  | mime-docx | DocX format support. |
@@ -92,12 +92,12 @@ langfun/core/llms/groq.py,sha256=dCnR3eAECEKuKKAAj-PDTs8NRHl6CQPdf57m1f6a79U,103
92
92
  langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
93
93
  langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
94
94
  langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
95
- langfun/core/llms/openai.py,sha256=_VwOSuDsyXDngUM2iiES0CW1aN0BzMjXNBMegLzm4J4,23209
96
- langfun/core/llms/openai_test.py,sha256=_8cd3VRNEUfE0-Ko1RiM6MlC5hjalRj7nYTJNhG1p3E,18907
95
+ langfun/core/llms/openai.py,sha256=l49v6RubfInvV0iG114AymTKNogTX4u4N-UFCeSgIxw,20963
96
+ langfun/core/llms/openai_test.py,sha256=kOWa1nf-nJvtYY10REUw5wojh3ZgfU8tRaCZ8wUgJbA,16623
97
97
  langfun/core/llms/rest.py,sha256=sWbYUV8S3SuOg9giq7xwD-xDRfaF7NP_ig7bI52-Rj4,3442
98
98
  langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
99
99
  langfun/core/llms/vertexai.py,sha256=EZhJrdN-SsZVV0KT3NHzaJLVKsNMxCT6M3W6f5fpIWQ,27068
100
- langfun/core/llms/vertexai_test.py,sha256=nGv59yE4xu1zUxqmP_U941QjSBrr_sW15Q2YakuxMv4,16982
100
+ langfun/core/llms/vertexai_test.py,sha256=qapDa7fvLkHm3BhG12a-HopxGCn625r-eVud2QqRITo,17120
101
101
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
102
102
  langfun/core/llms/cache/base.py,sha256=rt3zwmyw0y9jsSGW-ZbV1vAfLxQ7_3AVk0l2EySlse4,3918
103
103
  langfun/core/llms/cache/in_memory.py,sha256=l6b-iU9OTfTRo9Zmg4VrQIuArs4cCJDOpXiEpvNocjo,5004
@@ -148,8 +148,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
148
148
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
149
149
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
150
150
  langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
151
- langfun-0.1.2.dev202412020805.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
- langfun-0.1.2.dev202412020805.dist-info/METADATA,sha256=c3yjg186RyrDaIHGLMpmXsI7-Kqj4V1vLGxYsjJJN2Y,8890
153
- langfun-0.1.2.dev202412020805.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
- langfun-0.1.2.dev202412020805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
- langfun-0.1.2.dev202412020805.dist-info/RECORD,,
151
+ langfun-0.1.2.dev202412030000.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
152
+ langfun-0.1.2.dev202412030000.dist-info/METADATA,sha256=PoROaIMontFjWm5sPVdo3DpJATWFoFbO8IOr9t-3K2o,8651
153
+ langfun-0.1.2.dev202412030000.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
154
+ langfun-0.1.2.dev202412030000.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
155
+ langfun-0.1.2.dev202412030000.dist-info/RECORD,,