posthog 7.0.1__py3-none-any.whl → 7.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,612 @@
1
+ """
2
+ Tests for FlagDefinitionCacheProvider functionality.
3
+
4
+ These tests follow the patterns from the TypeScript implementation in posthog-js/packages/node.
5
+ """
6
+
7
+ import threading
8
+ import unittest
9
+ from typing import Optional
10
+ from unittest import mock
11
+
12
+ from posthog.client import Client
13
+ from posthog.flag_definition_cache import (
14
+ FlagDefinitionCacheData,
15
+ FlagDefinitionCacheProvider,
16
+ )
17
+ from posthog.request import GetResponse
18
+ from posthog.test.test_utils import FAKE_TEST_API_KEY
19
+
20
+
21
class MockCacheProvider:
    """In-memory test double exposing the FlagDefinitionCacheProvider methods.

    Each method increments its own call counter. When the matching
    ``*_error`` attribute is set to an exception instance, the method raises
    that exception (after counting the call) instead of doing its normal work.
    Tests drive the double by assigning the public attributes directly.
    """

    def __init__(self):
        # Payload returned by get_flag_definitions(); replaced by
        # on_flag_definitions_received() on a successful call.
        self.stored_data: Optional[FlagDefinitionCacheData] = None
        # Answer given by should_fetch_flag_definitions().
        self.should_fetch_return_value = True

        # Number of times each method has been invoked.
        self.get_call_count = 0
        self.should_fetch_call_count = 0
        self.on_received_call_count = 0
        self.shutdown_call_count = 0

        # Optional exceptions to raise from the corresponding method.
        self.should_fetch_error: Optional[Exception] = None
        self.get_error: Optional[Exception] = None
        self.on_received_error: Optional[Exception] = None
        self.shutdown_error: Optional[Exception] = None

    def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]:
        """Return ``stored_data``, or raise ``get_error`` if one is configured."""
        self.get_call_count += 1
        # Compare against None explicitly: a configured exception must be
        # raised even if its type happens to evaluate as falsy.
        if self.get_error is not None:
            raise self.get_error
        return self.stored_data

    def should_fetch_flag_definitions(self) -> bool:
        """Return ``should_fetch_return_value``, or raise ``should_fetch_error``."""
        self.should_fetch_call_count += 1
        if self.should_fetch_error is not None:
            raise self.should_fetch_error
        return self.should_fetch_return_value

    def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None:
        """Store ``data`` as the cached payload, or raise ``on_received_error``.

        When the error is raised, ``stored_data`` is left unchanged.
        """
        self.on_received_call_count += 1
        if self.on_received_error is not None:
            raise self.on_received_error
        self.stored_data = data

    def shutdown(self) -> None:
        """Record the shutdown call, raising ``shutdown_error`` if configured."""
        self.shutdown_call_count += 1
        if self.shutdown_error is not None:
            raise self.shutdown_error
58
+
59
+
60
class TestFlagDefinitionCacheProvider(unittest.TestCase):
    """Shared fixture base for the FlagDefinitionCacheProvider test cases.

    Provides a fresh ``MockCacheProvider`` and a sample flag payload for every
    test, plus a helper that builds a ``Client`` wired to that provider.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Prevent real HTTP requests by replacing batch_post where the client
        # and the consumer modules look it up.
        cls.client_post_patcher = mock.patch("posthog.client.batch_post")
        cls.consumer_post_patcher = mock.patch("posthog.consumer.batch_post")
        for patcher in (cls.client_post_patcher, cls.consumer_post_patcher):
            patcher.start()
            # Register the stop right after each start: class cleanups still
            # run when setUpClass fails part-way (tearDownClass would not),
            # so an already-started patch can never leak into other tests.
            cls.addClassCleanup(patcher.stop)

    def setUp(self):
        self.cache_provider = MockCacheProvider()
        # Two flags, two group types and one cohort; several tests assert on
        # these exact counts and keys.
        self.sample_flags_data: FlagDefinitionCacheData = {
            "flags": [
                {"key": "test-flag", "active": True, "filters": {}},
                {"key": "another-flag", "active": False, "filters": {}},
            ],
            "group_type_mapping": {"0": "company", "1": "project"},
            "cohorts": {"1": {"properties": []}},
        }

    def _create_client_with_cache(self) -> Client:
        """Create a client that uses ``self.cache_provider`` as its flag cache."""
        return Client(
            FAKE_TEST_API_KEY,
            personal_api_key="test-personal-key",
            flag_definition_cache_provider=self.cache_provider,
            sync_mode=True,
            enable_local_evaluation=False,  # Disable poller for tests
        )
100
+
101
+
102
class TestCacheInitialization(TestFlagDefinitionCacheProvider):
    """Checks whether a flag load is served from the cache or from the API."""

    @mock.patch("posthog.client.get")
    def test_uses_cached_data_when_should_fetch_returns_false(self, mock_get):
        """A populated cache is used when the provider declines to fetch."""
        provider = self.cache_provider
        provider.stored_data = self.sample_flags_data
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # No API request may have been issued.
        mock_get.assert_not_called()

        # Both cache hooks were consulted exactly once.
        self.assertEqual(provider.should_fetch_call_count, 1)
        self.assertEqual(provider.get_call_count, 1)

        # The client now holds the cached definitions.
        loaded_flags = posthog_client.feature_flags
        self.assertEqual(len(loaded_flags), 2)
        self.assertEqual(loaded_flags[0]["key"], "test-flag")

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_fetches_from_api_when_should_fetch_returns_true(self, mock_get):
        """The API is queried when the provider asks for a fetch."""
        provider = self.cache_provider
        provider.should_fetch_return_value = True
        api_response = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )
        mock_get.return_value = api_response

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # Exactly one API request went out.
        mock_get.assert_called_once()

        # should_fetch was asked, but the cache itself was never read.
        self.assertEqual(provider.should_fetch_call_count, 1)
        self.assertEqual(provider.get_call_count, 0)

        # The fetched payload was handed to the provider.
        self.assertEqual(provider.on_received_call_count, 1)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_emergency_fallback_when_cache_empty_and_no_flags(self, mock_get):
        """With should_fetch=False, an empty cache and no flags, the API is still hit."""
        provider = self.cache_provider
        provider.stored_data = None  # nothing cached
        provider.should_fetch_return_value = False
        api_response = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )
        mock_get.return_value = api_response

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # The fallback path results in a single API request...
        mock_get.assert_called_once()

        # ...whose payload is passed on to the provider.
        self.assertEqual(provider.on_received_call_count, 1)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_preserves_existing_flags_when_cache_returns_none(self, mock_get):
        """Flags already held by the client survive an empty cache read."""
        provider = self.cache_provider
        provider.stored_data = None  # nothing cached
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()

        # Seed the client as though an earlier fetch had succeeded.
        seed = self.sample_flags_data
        posthog_client.feature_flags = seed["flags"]
        posthog_client.group_type_mapping = seed["group_type_mapping"]
        posthog_client.cohorts = seed["cohorts"]

        posthog_client._load_feature_flags()

        # Flags are already present, so no API request is expected.
        mock_get.assert_not_called()

        # The seeded flags are still in place.
        remaining_flags = posthog_client.feature_flags
        self.assertEqual(len(remaining_flags), 2)
        self.assertEqual(remaining_flags[0]["key"], "test-flag")

        posthog_client.join()
195
+
196
+
197
class TestFetchCoordination(TestFlagDefinitionCacheProvider):
    """Checks how the client sequences the provider hooks around each load."""

    @mock.patch("posthog.client.get")
    def test_calls_should_fetch_before_each_poll(self, mock_get):
        """Every load consults should_fetch_flag_definitions once."""
        provider = self.cache_provider
        provider.should_fetch_return_value = True
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()

        # Initial load.
        posthog_client._load_feature_flags()
        self.assertEqual(provider.should_fetch_call_count, 1)

        # Follow-up load.
        posthog_client._load_feature_flags()
        self.assertEqual(provider.should_fetch_call_count, 2)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_does_not_call_on_received_when_fetch_skipped(self, mock_get):
        """Loading from the cache never triggers on_flag_definitions_received."""
        provider = self.cache_provider
        provider.stored_data = self.sample_flags_data
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # Nothing was fetched, so nothing is reported back to the provider.
        self.assertEqual(provider.on_received_call_count, 0)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_stores_data_in_cache_after_api_fetch(self, mock_get):
        """A successful API fetch is forwarded to the provider for storage."""
        provider = self.cache_provider
        provider.should_fetch_return_value = True
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # The provider received, and therefore now holds, the payload.
        self.assertEqual(provider.on_received_call_count, 1)
        self.assertIsNotNone(provider.stored_data)
        self.assertEqual(len(provider.stored_data["flags"]), 2)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_304_not_modified_does_not_update_cache(self, mock_get):
        """A not-modified response leaves both the cache and the flags alone."""
        provider = self.cache_provider
        provider.should_fetch_return_value = True

        # Round one: a full response that populates the flags and the ETag.
        full_response = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )
        mock_get.return_value = full_response

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # Sanity-check the first round.
        self.assertEqual(provider.on_received_call_count, 1)
        self.assertEqual(len(posthog_client.feature_flags), 2)

        # Round two: the API answers "not modified" with no payload.
        unchanged_response = GetResponse(
            data=None, etag="test-etag", not_modified=True
        )
        mock_get.return_value = unchanged_response

        posthog_client._load_feature_flags()

        # Two API requests in total.
        self.assertEqual(mock_get.call_count, 2)

        # should_fetch was consulted for both rounds.
        self.assertEqual(provider.should_fetch_call_count, 2)

        # No second on_received call: there was no new data to store.
        self.assertEqual(provider.on_received_call_count, 1)

        # The flags from round one are still loaded.
        self.assertEqual(len(posthog_client.feature_flags), 2)

        posthog_client.join()
291
+
292
+
293
class TestErrorHandling(TestFlagDefinitionCacheProvider):
    """Checks client behaviour when individual provider hooks raise."""

    @mock.patch("posthog.client.get")
    def test_should_fetch_error_defaults_to_fetching(self, mock_get):
        """A failing should_fetch hook results in an API fetch."""
        provider = self.cache_provider
        provider.should_fetch_error = Exception("Lock acquisition failed")
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # The API is queried despite the hook failure.
        mock_get.assert_called_once()

        # The fetched flags end up on the client.
        self.assertEqual(len(posthog_client.feature_flags), 2)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_get_error_falls_back_to_api_fetch(self, mock_get):
        """A failing cache read results in an API fetch."""
        provider = self.cache_provider
        provider.get_error = Exception("Cache read failed")
        provider.should_fetch_return_value = False
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # The API is used as the fallback source.
        mock_get.assert_called_once()

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_on_received_error_keeps_flags_in_memory(self, mock_get):
        """A failing cache write does not discard the freshly fetched flags."""
        provider = self.cache_provider
        provider.on_received_error = Exception("Cache write failed")
        provider.should_fetch_return_value = True
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # The client still holds the fetched flags.
        loaded_flags = posthog_client.feature_flags
        self.assertEqual(len(loaded_flags), 2)
        self.assertEqual(loaded_flags[0]["key"], "test-flag")

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_shutdown_error_is_logged_but_continues(self, mock_get):
        """An exception from the provider's shutdown does not escape join()."""
        provider = self.cache_provider
        provider.shutdown_error = Exception("Lock release failed")
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # If join() propagated the provider error, the test would fail here.
        posthog_client.join()

        # The provider's shutdown hook was invoked once.
        self.assertEqual(provider.shutdown_call_count, 1)
370
+
371
+
372
class TestShutdownLifecycle(TestFlagDefinitionCacheProvider):
    """Checks that joining the client reaches the provider's shutdown hook."""

    @mock.patch("posthog.client.get")
    def test_shutdown_calls_cache_provider_shutdown(self, mock_get):
        """join() invokes the provider's shutdown once after an API load."""
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # Shut the client down.
        posthog_client.join()

        self.assertEqual(self.cache_provider.shutdown_call_count, 1)

    @mock.patch("posthog.client.get")
    def test_shutdown_called_even_without_fetching(self, mock_get):
        """join() invokes the provider's shutdown after a cache-only load."""
        provider = self.cache_provider
        provider.stored_data = self.sample_flags_data
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()
        posthog_client.join()

        # The shutdown hook runs regardless of where the flags came from.
        self.assertEqual(provider.shutdown_call_count, 1)

    @mock.patch("posthog.client.get")
    def test_multiple_join_calls_only_shutdown_once(self, mock_get):
        """Repeated join() calls reach the provider's shutdown at least once."""
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # Join three times in a row.
        for _ in range(3):
            posthog_client.join()

        # Only a lower bound is asserted: this documents the current
        # behaviour, in which there is no once-only guard around shutdown.
        self.assertGreaterEqual(self.cache_provider.shutdown_call_count, 1)
421
+
422
+
423
class TestBackwardCompatibility(TestFlagDefinitionCacheProvider):
    """Checks flag loading for a client built without any cache provider."""

    @mock.patch("posthog.client.get")
    def test_works_without_cache_provider(self, mock_get):
        """A provider-less client loads its flags straight from the API."""
        api_response = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )
        mock_get.return_value = api_response

        # Note: no flag_definition_cache_provider argument here.
        plain_client = Client(
            FAKE_TEST_API_KEY,
            personal_api_key="test-personal-key",
            sync_mode=True,
            enable_local_evaluation=False,
        )
        plain_client._load_feature_flags()

        # A single API request was made...
        mock_get.assert_called_once()

        # ...and its flags are now on the client.
        self.assertEqual(len(plain_client.feature_flags), 2)

        plain_client.join()
449
+
450
+
451
class TestDataIntegrity(TestFlagDefinitionCacheProvider):
    """Checks that client state matches what the cache or the API supplied."""

    @mock.patch("posthog.client.get")
    def test_cached_flags_available_for_evaluation(self, mock_get):
        """A flag loaded from the cache can be looked up by key on the client."""
        rollout_group = {
            "properties": [],
            "rollout_percentage": 100,
        }
        cached_flag = {
            "key": "test-flag",
            "active": True,
            "filters": {"groups": [rollout_group]},
        }
        provider = self.cache_provider
        provider.should_fetch_return_value = False
        provider.stored_data = {
            "flags": [cached_flag],
            "group_type_mapping": {},
            "cohorts": {},
        }

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        # The single cached flag is present and indexed by its key.
        self.assertEqual(len(posthog_client.feature_flags), 1)
        indexed_flag = posthog_client.feature_flags_by_key["test-flag"]
        self.assertEqual(indexed_flag["key"], "test-flag")

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_group_type_mapping_loaded_from_cache(self, mock_get):
        """The cached group type mapping is exposed on the client."""
        provider = self.cache_provider
        provider.stored_data = self.sample_flags_data
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        group_types = posthog_client.group_type_mapping
        self.assertEqual(group_types["0"], "company")
        self.assertEqual(group_types["1"], "project")

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_cohorts_loaded_from_cache(self, mock_get):
        """The cached cohorts are exposed on the client."""
        provider = self.cache_provider
        provider.stored_data = self.sample_flags_data
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()
        posthog_client._load_feature_flags()

        self.assertIn("1", posthog_client.cohorts)

        posthog_client.join()

    @mock.patch("posthog.client.get")
    def test_cache_updated_when_api_returns_new_data(self, mock_get):
        """Stale cached data is replaced once the API delivers new definitions."""
        provider = self.cache_provider

        # Phase 1: only stale data in the cache, and fetching is declined.
        stale_payload: FlagDefinitionCacheData = {
            "flags": [{"key": "old-flag", "active": True, "filters": {}}],
            "group_type_mapping": {},
            "cohorts": {},
        }
        provider.stored_data = stale_payload
        provider.should_fetch_return_value = False

        posthog_client = self._create_client_with_cache()

        # The first load is served from the cache.
        posthog_client._load_feature_flags()
        self.assertEqual(posthog_client.feature_flags[0]["key"], "old-flag")
        self.assertEqual(provider.on_received_call_count, 0)

        # Phase 2: fetching is allowed and the API has fresh definitions.
        provider.should_fetch_return_value = True
        fresh_payload: FlagDefinitionCacheData = {
            "flags": [{"key": "new-flag", "active": True, "filters": {}}],
            "group_type_mapping": {"0": "company"},
            "cohorts": {"1": {"properties": []}},
        }
        mock_get.return_value = GetResponse(
            data=fresh_payload, etag="new-etag", not_modified=False
        )

        posthog_client._load_feature_flags()

        # The client switched over to the fresh definitions.
        self.assertEqual(posthog_client.feature_flags[0]["key"], "new-flag")
        self.assertEqual(posthog_client.group_type_mapping["0"], "company")

        # The provider was handed, and stored, the fresh payload.
        self.assertEqual(provider.on_received_call_count, 1)
        self.assertEqual(provider.stored_data["flags"][0]["key"], "new-flag")

        posthog_client.join()
554
+
555
+
556
class TestConcurrency(TestFlagDefinitionCacheProvider):
    """Checks that flag loading tolerates being called from several threads."""

    @mock.patch("posthog.client.get")
    def test_concurrent_load_feature_flags_is_thread_safe(self, mock_get):
        """Five threads calling _load_feature_flags raise nothing and load the flags."""
        mock_get.return_value = GetResponse(
            data=self.sample_flags_data, etag="test-etag", not_modified=False
        )

        posthog_client = self._create_client_with_cache()
        # Exceptions raised inside worker threads are collected here, since
        # they would otherwise not reach the test's own thread.
        errors = []

        def load_flags():
            try:
                posthog_client._load_feature_flags()
            except Exception as exc:
                errors.append(exc)

        # Create the five workers up front, then start them all.
        workers = []
        for _ in range(5):
            workers.append(threading.Thread(target=load_flags))
        for worker in workers:
            worker.start()
        # Wait for every worker to finish before asserting.
        for worker in workers:
            worker.join()

        # No worker may have recorded an exception.
        self.assertEqual(len(errors), 0, f"Unexpected errors: {errors}")

        # The flags are present on the client afterwards.
        self.assertIsNotNone(posthog_client.feature_flags)
        self.assertEqual(len(posthog_client.feature_flags), 2)

        posthog_client.join()
590
+
591
+
592
class TestProtocolCompliance(unittest.TestCase):
    """Checks isinstance() results against FlagDefinitionCacheProvider."""

    def test_mock_provider_is_protocol_instance(self):
        """A MockCacheProvider instance passes the isinstance check."""
        self.assertIsInstance(MockCacheProvider(), FlagDefinitionCacheProvider)

    def test_incomplete_provider_is_not_protocol_instance(self):
        """An object offering only get_flag_definitions fails the isinstance check."""

        class IncompleteProvider:
            # Deliberately defines just one of the provider methods.
            def get_flag_definitions(self):
                return None

        partial = IncompleteProvider()
        self.assertNotIsInstance(partial, FlagDefinitionCacheProvider)
609
+
610
+
611
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()