Skip to content

Commit bcdf60d

Browse files
committed
squash! squash! Fix DatabaseSyncToAsync
Move it to DatabaseSyncToAsyncForTests
1 parent f701174 commit bcdf60d

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

channels/db.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,39 @@
11
from django.db import connections
2+
from django.db import close_old_connections
23

34
from asgiref.sync import SyncToAsync
45

5-
main_thread_connections = {name: connections[name] for name in connections}
6-
7-
8-
def _inherit_main_thread_connections():
9-
"""Copy/use DB connections in atomic block from main thread.
10-
11-
This is required for tests using Django's TestCase.
12-
"""
13-
for name in main_thread_connections:
14-
if main_thread_connections[name].in_atomic_block:
15-
connections[name] = main_thread_connections[name]
16-
connections[name].inc_thread_sharing()
17-
186

197
class DatabaseSyncToAsync(SyncToAsync):
208
"""
219
SyncToAsync version that cleans up old database connections.
2210
"""
2311

12+
def thread_handler(self, loop, *args, **kwargs):
13+
close_old_connections()
14+
try:
15+
return super().thread_handler(loop, *args, **kwargs)
16+
finally:
17+
close_old_connections()
18+
19+
20+
class DatabaseSyncToAsyncForTests(SyncToAsync):
21+
def __init__(self, *args, **kwargs):
22+
self.main_thread_connections = {name: connections[name] for name in connections}
23+
super().__init__(*args, **kwargs)
24+
25+
def _inherit_main_thread_connections(self):
26+
"""Copy/use DB connections in atomic block from main thread.
27+
28+
This is required for tests using Django's TestCase.
29+
"""
30+
from django.db import connections
31+
32+
for name in self.main_thread_connections:
33+
if self.main_thread_connections[name].in_atomic_block:
34+
connections[name] = self.main_thread_connections[name]
35+
connections[name].inc_thread_sharing()
36+
2437
def _close_old_connections(self):
2538
"""Like django.db.close_old_connections, but skipping in_atomic_block."""
2639
for conn in connections.all():
@@ -29,12 +42,12 @@ def _close_old_connections(self):
2942
conn.close_if_unusable_or_obsolete()
3043

3144
def thread_handler(self, loop, *args, **kwargs):
32-
_inherit_main_thread_connections()
33-
self._close_old_connections()
45+
self._inherit_main_thread_connections()
46+
close_old_connections()
3447
try:
3548
return super().thread_handler(loop, *args, **kwargs)
3649
finally:
37-
self._close_old_connections()
50+
close_old_connections()
3851

3952

4053
# The class is TitleCased, but we want to encourage use as a callable/decorator

0 commit comments

Comments
 (0)