Skip to content

Commit f0d4e92

Browse files
committed
INTPYTHON-527 Add Queryable Encryption support
1 parent d5aa1a7 commit f0d4e92

27 files changed

+1123
-9
lines changed

.evergreen/config.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,28 @@ buildvariants:
9090
tasks:
9191
- name: run-tests
9292

93+
- name: tests-7-noauth-nossl
94+
display_name: Run Tests 7.0 NoAuth NoSSL
95+
run_on: rhel87-small
96+
expansions:
97+
MONGODB_VERSION: "7.0"
98+
TOPOLOGY: server
99+
AUTH: "noauth"
100+
SSL: "nossl"
101+
tasks:
102+
- name: run-tests
103+
104+
- name: tests-7-auth-ssl
105+
display_name: Run Tests 7.0 Auth SSL
106+
run_on: rhel87-small
107+
expansions:
108+
MONGODB_VERSION: "7.0"
109+
TOPOLOGY: server
110+
AUTH: "auth"
111+
SSL: "ssl"
112+
tasks:
113+
- name: run-tests
114+
93115
- name: tests-8-noauth-nossl
94116
display_name: Run Tests 8.0 NoAuth NoSSL
95117
run_on: rhel87-small

django_mongodb_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .indexes import register_indexes # noqa: E402
1515
from .lookups import register_lookups # noqa: E402
1616
from .query import register_nodes # noqa: E402
17+
from .routers import register_routers # noqa: E402
1718

1819
__all__ = ["parse_uri"]
1920

@@ -25,3 +26,4 @@
2526
register_indexes()
2627
register_lookups()
2728
register_nodes()
29+
register_routers()

django_mongodb_backend/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,7 @@ def cursor(self):
229229

230230
def get_database_version(self):
231231
"""Return a tuple of the database's version."""
232-
return tuple(self.connection.server_info()["versionArray"])
232+
# Avoid using PyMongo to check the database version or require
233+
# pymongocrypt>=1.14.2 which will contain a fix for the `buildInfo`
234+
# command. https://jira.mongodb.org/browse/PYTHON-5429
235+
return tuple(self.connection.admin.command("buildInfo")["versionArray"])

django_mongodb_backend/encryption.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Queryable Encryption helper classes
2+
3+
4+
class EqualityQuery(dict):
5+
def __init__(self, *, contention=None):
6+
super().__init__(queryType="equality")
7+
if contention is not None:
8+
self["contention"] = contention
9+
10+
11+
class RangeQuery(dict):
12+
def __init__(
13+
self, *, contention=None, max=None, min=None, precision=None, sparsity=None, trimFactor=None
14+
):
15+
super().__init__(queryType="range")
16+
options = {
17+
"contention": contention,
18+
"max": max,
19+
"min": min,
20+
"precision": precision,
21+
"sparsity": sparsity,
22+
"trimFactor": trimFactor,
23+
}
24+
self.update({k: v for k, v in options.items() if v is not None})

django_mongodb_backend/features.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,9 +569,17 @@ def django_test_expected_failures(self):
569569
},
570570
}
571571

572+
@cached_property
573+
def mongodb_version(self):
574+
return self.connection.get_database_version() # e.g., (6, 3, 0)
575+
572576
@cached_property
573577
def is_mongodb_6_3(self):
574-
return self.connection.get_database_version() >= (6, 3)
578+
return self.mongodb_version >= (6, 3)
579+
580+
@cached_property
581+
def is_mongodb_7_0(self):
582+
return self.mongodb_version >= (7, 0)
575583

576584
@cached_property
577585
def supports_atlas_search(self):
@@ -601,3 +609,18 @@ def _supports_transactions(self):
601609
hello = client.command("hello")
602610
# a replica set or a sharded cluster
603611
return "setName" in hello or hello.get("msg") == "isdbgrid"
612+
613+
@cached_property
614+
def supports_queryable_encryption(self):
615+
"""
616+
Queryable Encryption requires a MongoDB 7.0 or later replica set or sharded
617+
cluster, as well as MonogDB Atlas or Enterprise.
618+
"""
619+
self.connection.ensure_connection()
620+
build_info = self.connection.connection.admin.command("buildInfo")
621+
is_enterprise = "enterprise" in build_info.get("modules")
622+
return (
623+
(is_enterprise or self.supports_atlas_search)
624+
and self._supports_transactions
625+
and self.is_mongodb_7_0
626+
)

django_mongodb_backend/fields/__init__.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,28 @@
33
from .duration import register_duration_field
44
from .embedded_model import EmbeddedModelField
55
from .embedded_model_array import EmbeddedModelArrayField
6+
from .encryption import (
7+
EncryptedBigIntegerField,
8+
EncryptedBinaryField,
9+
EncryptedBooleanField,
10+
EncryptedCharField,
11+
EncryptedDateField,
12+
EncryptedDateTimeField,
13+
EncryptedDecimalField,
14+
EncryptedEmailField,
15+
EncryptedFieldMixin,
16+
EncryptedFloatField,
17+
EncryptedGenericIPAddressField,
18+
EncryptedIntegerField,
19+
EncryptedPositiveBigIntegerField,
20+
EncryptedPositiveIntegerField,
21+
EncryptedPositiveSmallIntegerField,
22+
EncryptedSmallIntegerField,
23+
EncryptedTextField,
24+
EncryptedTimeField,
25+
EncryptedURLField,
26+
has_encrypted_fields,
27+
)
628
from .json import register_json_field
729
from .objectid import ObjectIdField
830
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
@@ -12,10 +34,31 @@
1234
"ArrayField",
1335
"EmbeddedModelArrayField",
1436
"EmbeddedModelField",
37+
"EncryptedBigIntegerField",
38+
"EncryptedBinaryField",
39+
"EncryptedBooleanField",
40+
"EncryptedCharField",
41+
"EncryptedDateField",
42+
"EncryptedDateTimeField",
43+
"EncryptedDecimalField",
44+
"EncryptedEmailField",
45+
"EncryptedFieldMixin",
46+
"EncryptedFloatField",
47+
"EncryptedGenericIPAddressField",
48+
"EncryptedIntegerField",
49+
"EncryptedPositiveBigIntegerField",
50+
"EncryptedPositiveIntegerField",
51+
"EncryptedPositiveSmallIntegerField",
52+
"EncryptedSmallIntegerField",
53+
"EncryptedTextField",
54+
"EncryptedTimeField",
55+
"EncryptedURLField",
1556
"ObjectIdAutoField",
1657
"ObjectIdField",
1758
"PolymorphicEmbeddedModelArrayField",
1859
"PolymorphicEmbeddedModelField",
60+
"has_encrypted_fields",
61+
"register_fields",
1962
"register_fields",
2063
]
2164

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from django.db import models
2+
3+
4+
def has_encrypted_fields(model):
5+
return any(getattr(field, "encrypted", False) for field in model._meta.fields)
6+
7+
8+
class EncryptedFieldMixin(models.Field):
9+
encrypted = True
10+
11+
def __init__(self, *args, queries=None, **kwargs):
12+
self.queries = queries
13+
super().__init__(*args, **kwargs)
14+
15+
def deconstruct(self):
16+
name, path, args, kwargs = super().deconstruct()
17+
18+
if self.queries is not None:
19+
kwargs["queries"] = self.queries
20+
21+
if path.startswith("django_mongodb_backend.fields.encrypted_model"):
22+
path = path.replace(
23+
"django_mongodb_backend.fields.encrypted_model",
24+
"django_mongodb_backend.fields",
25+
)
26+
27+
return name, path, args, kwargs
28+
29+
30+
class EncryptedBigIntegerField(EncryptedFieldMixin, models.BigIntegerField):
31+
pass
32+
33+
34+
class EncryptedBinaryField(EncryptedFieldMixin, models.BinaryField):
35+
pass
36+
37+
38+
class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField):
39+
pass
40+
41+
42+
class EncryptedCharField(EncryptedFieldMixin, models.CharField):
43+
pass
44+
45+
46+
class EncryptedDateField(EncryptedFieldMixin, models.DateField):
47+
pass
48+
49+
50+
class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
51+
pass
52+
53+
54+
class EncryptedDecimalField(EncryptedFieldMixin, models.DecimalField):
55+
pass
56+
57+
58+
class EncryptedEmailField(EncryptedFieldMixin, models.EmailField):
59+
pass
60+
61+
62+
class EncryptedFloatField(EncryptedFieldMixin, models.FloatField):
63+
pass
64+
65+
66+
class EncryptedGenericIPAddressField(EncryptedFieldMixin, models.GenericIPAddressField):
67+
pass
68+
69+
70+
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
71+
pass
72+
73+
74+
class EncryptedPositiveBigIntegerField(EncryptedFieldMixin, models.PositiveBigIntegerField):
75+
pass
76+
77+
78+
class EncryptedPositiveIntegerField(EncryptedFieldMixin, models.PositiveIntegerField):
79+
pass
80+
81+
82+
class EncryptedPositiveSmallIntegerField(EncryptedFieldMixin, models.PositiveSmallIntegerField):
83+
pass
84+
85+
86+
class EncryptedSmallIntegerField(EncryptedFieldMixin, models.SmallIntegerField):
87+
pass
88+
89+
90+
class EncryptedTimeField(EncryptedFieldMixin, models.TimeField):
91+
pass
92+
93+
94+
class EncryptedTextField(EncryptedFieldMixin, models.TextField):
95+
pass
96+
97+
98+
class EncryptedURLField(EncryptedFieldMixin, models.URLField):
99+
pass
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from bson import json_util
2+
from django.apps import apps
3+
from django.core.management.base import BaseCommand
4+
from django.db import DEFAULT_DB_ALIAS, connections, router
5+
from pymongo.encryption import ClientEncryption
6+
7+
8+
class Command(BaseCommand):
9+
help = "Generate a `schema_map` of encrypted fields for all encrypted"
10+
" models in the database for use with `AutoEncryptionOpts` in"
11+
" client configuration."
12+
13+
def add_arguments(self, parser):
14+
parser.add_argument(
15+
"--database",
16+
default=DEFAULT_DB_ALIAS,
17+
help="Specify the database to use for generating the encrypted"
18+
"fields map. Defaults to the 'default' database.",
19+
)
20+
parser.add_argument(
21+
"--kms-provider",
22+
default="local",
23+
help="Specify the KMS provider to use for encryption. Defaults to 'local'.",
24+
)
25+
26+
def handle(self, *args, **options):
27+
db = options["database"]
28+
kms_provider = options["kms_provider"]
29+
connection = connections[db]
30+
schema_map = json_util.dumps(
31+
self.get_encrypted_fields_map(connection, kms_provider), indent=2
32+
)
33+
self.stdout.write(schema_map)
34+
35+
def get_client_encryption(self, connection):
36+
client = connection.connection
37+
options = client._options.auto_encryption_opts
38+
key_vault_namespace = options._key_vault_namespace
39+
kms_providers = options._kms_providers
40+
return ClientEncryption(kms_providers, key_vault_namespace, client, client.codec_options)
41+
42+
def get_encrypted_fields_map(self, connection, kms_provider):
43+
schema_map = {}
44+
for app_config in apps.get_app_configs():
45+
for model in router.get_migratable_models(
46+
app_config, connection.settings_dict["NAME"], include_auto_created=False
47+
):
48+
if getattr(model, "encrypted", False):
49+
fields = connection.schema_editor()._get_encrypted_fields_map(model)
50+
ce = self.get_client_encryption(connection)
51+
master_key = connection.settings_dict.get("KMS_CREDENTIALS").get(kms_provider)
52+
for field in fields["fields"]:
53+
data_key = ce.create_data_key(
54+
kms_provider=kms_provider,
55+
master_key=master_key,
56+
)
57+
field["keyId"] = data_key
58+
schema_map[model._meta.db_table] = fields
59+
return schema_map

django_mongodb_backend/routers.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from django.apps import apps
2+
from django.core.exceptions import ImproperlyConfigured
3+
from django.db.utils import ConnectionRouter
24

3-
from django_mongodb_backend.models import EmbeddedModel
5+
from .fields import has_encrypted_fields
46

57

68
class MongoRouter:
@@ -9,10 +11,32 @@ def allow_migrate(self, db, app_label, model_name=None, **hints):
911
EmbeddedModels don't have their own collection and must be ignored by
1012
dumpdata.
1113
"""
14+
from django_mongodb_backend.models import EmbeddedModel # noqa: PLC0415
15+
1216
if not model_name:
1317
return None
1418
try:
1519
model = apps.get_model(app_label, model_name)
1620
except LookupError:
1721
return None
22+
1823
return False if issubclass(model, EmbeddedModel) else None
24+
25+
26+
def kms_provider(self, model, *args, **kwargs):
27+
for router in self.routers:
28+
func = getattr(router, "kms_provider", None)
29+
if func and callable(func):
30+
result = func(model, *args, **kwargs)
31+
if result is not None:
32+
return result
33+
if has_encrypted_fields(model):
34+
raise ImproperlyConfigured("No kms_provider found in database router.")
35+
return None
36+
37+
38+
def register_routers():
39+
"""
40+
Patch the ConnectionRouter to use the custom kms_provider method.
41+
"""
42+
ConnectionRouter.kms_provider = kms_provider

0 commit comments

Comments
 (0)