@posthog/ai 5.2.0 → 5.2.2

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in that public registry.
@@ -3,10 +3,11 @@ import { withPrivacyMode, getModelParams } from '../utils'
3
3
  import { BaseCallbackHandler } from '@langchain/core/callbacks/base'
4
4
  import type { Serialized } from '@langchain/core/load/serializable'
5
5
  import type { ChainValues } from '@langchain/core/utils/types'
6
- import type { BaseMessage } from '@langchain/core/messages'
7
6
  import type { LLMResult } from '@langchain/core/outputs'
8
7
  import type { AgentAction, AgentFinish } from '@langchain/core/agents'
9
8
  import type { DocumentInterface } from '@langchain/core/documents'
9
+ import { ToolCall } from '@langchain/core/messages/tool'
10
+ import { BaseMessage } from '@langchain/core/messages'
10
11
 
11
12
  interface SpanMetadata {
12
13
  /** Name of the trace/span (e.g. chain name) */
@@ -498,11 +499,21 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
498
499
  return undefined
499
500
  }
500
501
 
502
+ private _convertLcToolCallsToOai(toolCalls: ToolCall[]): Record<string, any>[] {
503
+ return toolCalls.map((toolCall: ToolCall) => ({
504
+ type: 'function',
505
+ id: toolCall.id,
506
+ function: {
507
+ name: toolCall.name,
508
+ arguments: JSON.stringify(toolCall.args),
509
+ },
510
+ }))
511
+ }
512
+
501
513
  private _convertMessageToDict(message: any): Record<string, any> {
502
514
  let messageDict: Record<string, any> = {}
503
515
 
504
- // Check the _getType() method or type property instead of instanceof
505
- const messageType = message._getType?.() || message.type
516
+ const messageType: string = message.getType()
506
517
 
507
518
  switch (messageType) {
508
519
  case 'human':
@@ -510,6 +521,11 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
510
521
  break
511
522
  case 'ai':
512
523
  messageDict = { role: 'assistant', content: message.content }
524
+
525
+ if (message.tool_calls) {
526
+ messageDict.tool_calls = this._convertLcToolCallsToOai(message.tool_calls)
527
+ }
528
+
513
529
  break
514
530
  case 'system':
515
531
  messageDict = { role: 'system', content: message.content }
@@ -521,12 +537,14 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
521
537
  messageDict = { role: 'function', content: message.content }
522
538
  break
523
539
  default:
524
- messageDict = { role: messageType || 'unknown', content: String(message.content) }
540
+ messageDict = { role: messageType, content: String(message.content) }
541
+ break
525
542
  }
526
543
 
527
544
  if (message.additional_kwargs) {
528
545
  messageDict = { ...messageDict, ...message.additional_kwargs }
529
546
  }
547
+
530
548
  return messageDict
531
549
  }
532
550
 
@@ -1,4 +1,4 @@
1
- import OpenAIOrignal, { ClientOptions } from 'openai'
1
+ import { OpenAI as OpenAIOrignal, ClientOptions } from 'openai'
2
2
  import { PostHog } from 'posthog-node'
3
3
  import { v4 as uuidv4 } from 'uuid'
4
4
  import { formatResponseOpenAI, MonitoringParams, sendEventToPosthog } from '../utils'
@@ -6,6 +6,10 @@ import type { APIPromise } from 'openai'
6
6
  import type { Stream } from 'openai/streaming'
7
7
  import type { ParsedResponse } from 'openai/resources/responses/responses'
8
8
 
9
+ const Chat = OpenAIOrignal.Chat
10
+ const Completions = Chat.Completions
11
+ const Responses = OpenAIOrignal.Responses
12
+
9
13
  type ChatCompletion = OpenAIOrignal.ChatCompletion
10
14
  type ChatCompletionChunk = OpenAIOrignal.ChatCompletionChunk
11
15
  type ChatCompletionCreateParamsBase = OpenAIOrignal.Chat.Completions.ChatCompletionCreateParams
@@ -37,7 +41,7 @@ export class PostHogOpenAI extends OpenAIOrignal {
37
41
  }
38
42
  }
39
43
 
40
- export class WrappedChat extends OpenAIOrignal.Chat {
44
+ export class WrappedChat extends Chat {
41
45
  constructor(parentClient: PostHogOpenAI, phClient: PostHog) {
42
46
  super(parentClient)
43
47
  this.completions = new WrappedCompletions(parentClient, phClient)
@@ -46,7 +50,7 @@ export class WrappedChat extends OpenAIOrignal.Chat {
46
50
  public completions: WrappedCompletions
47
51
  }
48
52
 
49
- export class WrappedCompletions extends OpenAIOrignal.Chat.Completions {
53
+ export class WrappedCompletions extends Completions {
50
54
  private readonly phClient: PostHog
51
55
 
52
56
  constructor(client: OpenAIOrignal, phClient: PostHog) {
@@ -223,7 +227,7 @@ export class WrappedCompletions extends OpenAIOrignal.Chat.Completions {
223
227
  }
224
228
  }
225
229
 
226
- export class WrappedResponses extends OpenAIOrignal.Responses {
230
+ export class WrappedResponses extends Responses {
227
231
  private readonly phClient: PostHog
228
232
 
229
233
  constructor(client: OpenAIOrignal, phClient: PostHog) {
@@ -0,0 +1,48 @@
1
+ import { LangChainCallbackHandler } from '../src/langchain/callbacks'
2
+ import { PostHog } from 'posthog-node'
3
+ import { AIMessage } from '@langchain/core/messages'
4
+
5
+ const mockPostHogClient = {
6
+ capture: jest.fn(),
7
+ } as unknown as PostHog
8
+
9
+ describe('LangChainCallbackHandler', () => {
10
+ let handler: LangChainCallbackHandler
11
+
12
+ beforeEach(() => {
13
+ handler = new LangChainCallbackHandler({
14
+ client: mockPostHogClient,
15
+ })
16
+ jest.clearAllMocks()
17
+ })
18
+
19
+ it('should convert AIMessage with tool calls to dict format', () => {
20
+ const toolCalls = [
21
+ {
22
+ id: 'call_123',
23
+ name: 'get_weather',
24
+ args: { city: 'San Francisco', units: 'celsius' },
25
+ },
26
+ ]
27
+
28
+ const aiMessage = new AIMessage({
29
+ content: "I'll check the weather for you.",
30
+ tool_calls: toolCalls,
31
+ })
32
+
33
+ const result = (handler as any)._convertMessageToDict(aiMessage)
34
+
35
+ expect(result.role).toBe('assistant')
36
+ expect(result.content).toBe("I'll check the weather for you.")
37
+ expect(result.tool_calls).toEqual([
38
+ {
39
+ type: 'function',
40
+ id: 'call_123',
41
+ function: {
42
+ name: 'get_weather',
43
+ arguments: '{"city":"San Francisco","units":"celsius"}',
44
+ },
45
+ },
46
+ ])
47
+ })
48
+ })