Skip to content

Commit a6d61da

Browse files
authored
feat: add db alias (#21)
* feat: add db alias * feat: add db alias * docs: add CASBIN_DB_ALIAS argument to readme
1 parent 15e7887 commit a6d61da

File tree

7 files changed

+41
-26
lines changed

7 files changed

+41
-26
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def hello(request):
7070
A string containing the file location of your casbin model.
7171

7272
### `CASBIN_ADAPTER`
73-
A string containing the adapter import path. Defaults to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter`
73+
A string containing the adapter import path. Default to the django adapter shipped with this package: `casbin_adapter.adapter.Adapter`
7474

7575
### `CASBIN_ADAPTER_ARGS`
7676
A tuple of arguments to be passed into the constructor of the adapter specified
@@ -80,6 +80,9 @@ E.g. if you wish to use the file adapter
8080
set the adapter to `casbin.persist.adapters.FileAdapter` and use
8181
`CASBIN_ADAPTER_ARGS = ('path/to/policy_file.csv',)`
8282

83+
### `CASBIN_DB_ALIAS`
84+
The database the adapter uses. Default to "default".
85+
8386
### `CASBIN_WATCHER`
8487
Watcher instance to be set as the watcher on the enforcer instance.
8588

casbin_adapter/adapter.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
class Adapter(persist.Adapter):
1212
"""the interface for Casbin adapters."""
1313

14+
def __init__(self, db_alias="default"):
15+
self.db_alias = db_alias
16+
1417
def load_policy(self, model):
1518
"""loads all policy rules from the storage."""
1619
try:
17-
lines = CasbinRule.objects.all()
20+
lines = CasbinRule.objects.using(self.db_alias).all()
1821

1922
for line in lines:
2023
persist.load_policy_line(str(line), model)
@@ -41,7 +44,7 @@ def save_policy(self, model):
4144
"""saves all policy rules to the storage."""
4245
# See https://casbin.org/docs/en/adapters#autosave
4346
# for why this is deleting all rules
44-
CasbinRule.objects.all().delete()
47+
CasbinRule.objects.using(self.db_alias).all().delete()
4548

4649
lines = []
4750
for sec in ["p", "g"]:
@@ -50,7 +53,7 @@ def save_policy(self, model):
5053
for ptype, ast in model.model[sec].items():
5154
for rule in ast.policy:
5255
lines.append(self._create_policy_line(ptype, rule))
53-
CasbinRule.objects.bulk_create(lines)
56+
CasbinRule.objects.using(self.db_alias).bulk_create(lines)
5457
return True
5558

5659
def add_policy(self, sec, ptype, rule):
@@ -63,7 +66,7 @@ def remove_policy(self, sec, ptype, rule):
6366
query_params = {"ptype": ptype}
6467
for i, v in enumerate(rule):
6568
query_params["v{}".format(i)] = v
66-
rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete()
69+
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
6770
return True if rows_deleted > 0 else False
6871

6972
def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
@@ -77,5 +80,5 @@ def remove_filtered_policy(self, sec, ptype, field_index, *field_values):
7780
return False
7881
for i, v in enumerate(field_values):
7982
query_params["v{}".format(i + field_index)] = v
80-
rows_deleted, _ = CasbinRule.objects.filter(**query_params).delete()
83+
rows_deleted, _ = CasbinRule.objects.using(self.db_alias).filter(**query_params).delete()
8184
return True if rows_deleted > 0 else False

casbin_adapter/apps.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from django.apps import AppConfig
2-
from django.db import connection
3-
from django.db.utils import OperationalError, ProgrammingError
42

53

64
class CasbinAdapterConfig(AppConfig):

casbin_adapter/enforcer.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
from casbin import Enforcer
77

8-
from .adapter import Adapter
98
from .utils import import_class
109

1110
logger = logging.getLogger(__name__)
1211

1312

1413
class ProxyEnforcer(Enforcer):
1514
_initialized = False
15+
db_alias = "default"
1616

1717
def __init__(self, *args, **kwargs):
1818
if self._initialized:
@@ -27,8 +27,9 @@ def _load(self):
2727
model = getattr(settings, "CASBIN_MODEL")
2828
adapter_loc = getattr(settings, "CASBIN_ADAPTER", "casbin_adapter.adapter.Adapter")
2929
adapter_args = getattr(settings, "CASBIN_ADAPTER_ARGS", tuple())
30+
self.db_alias = getattr(settings, "CASBIN_DB_ALIAS", "default")
3031
Adapter = import_class(adapter_loc)
31-
adapter = Adapter(*adapter_args)
32+
adapter = Adapter(self.db_alias, *adapter_args)
3233

3334
super().__init__(model, adapter)
3435
logger.debug("Casbin enforcer initialised")
@@ -44,7 +45,7 @@ def _load(self):
4445
def __getattribute__(self, name):
4546
safe_methods = ["__init__", "_load", "_initialized"]
4647
if not super().__getattribute__("_initialized") and name not in safe_methods:
47-
initialize_enforcer()
48+
initialize_enforcer(self.db_alias)
4849
if not super().__getattribute__("_initialized"):
4950
raise Exception(
5051
(
@@ -59,17 +60,29 @@ def __getattribute__(self, name):
5960
enforcer = ProxyEnforcer()
6061

6162

62-
def initialize_enforcer():
63+
def initialize_enforcer(db_alias=None):
6364
try:
64-
with connection.cursor() as cursor:
65-
cursor.execute(
66-
"""
67-
SELECT app, name applied FROM django_migrations
68-
WHERE app = 'casbin_adapter' AND name = '0001_initial';
69-
"""
70-
)
71-
row = cursor.fetchone()
72-
if row:
73-
enforcer._load()
65+
row = None
66+
if db_alias:
67+
with connection[db_alias].cursor() as cursor:
68+
cursor.execute(
69+
"""
70+
SELECT app, name applied FROM django_migrations
71+
WHERE app = 'casbin_adapter' AND name = '0001_initial';
72+
"""
73+
)
74+
row = cursor.fetchone()
75+
else:
76+
with connection.cursor() as cursor:
77+
cursor.execute(
78+
"""
79+
SELECT app, name applied FROM django_migrations
80+
WHERE app = 'casbin_adapter' AND name = '0001_initial';
81+
"""
82+
)
83+
row = cursor.fetchone()
84+
85+
if row:
86+
enforcer._load()
7487
except (OperationalError, ProgrammingError):
7588
pass

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
casbin==1.16.10
1+
casbin>=1.16.10
22
Django
33

requirements_dev.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
-r requirements.txt
2-
setuptools==60.2.0
3-
simpleeval==0.9.12
2+
setuptools==60.2.0

tests/test_adapter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import casbin
33
import simpleeval
4-
from unittest import TestCase
54

65
from django.test import TestCase
76
from casbin_adapter.models import CasbinRule

0 commit comments

Comments
 (0)