Skip to content

Commit 796855e

Browse files
committed
INTPYTHON-527 Add Queryable Encryption support
1 parent d5aa1a7 commit 796855e

28 files changed

+1181
-10
lines changed

.evergreen/config.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,28 @@ buildvariants:
9090
tasks:
9191
- name: run-tests
9292

93+
- name: tests-7-noauth-nossl
94+
display_name: Run Tests 7.0 NoAuth NoSSL
95+
run_on: rhel87-small
96+
expansions:
97+
MONGODB_VERSION: "7.0"
98+
TOPOLOGY: server
99+
AUTH: "noauth"
100+
SSL: "nossl"
101+
tasks:
102+
- name: run-tests
103+
104+
- name: tests-7-auth-ssl
105+
display_name: Run Tests 7.0 Auth SSL
106+
run_on: rhel87-small
107+
expansions:
108+
MONGODB_VERSION: "7.0"
109+
TOPOLOGY: server
110+
AUTH: "auth"
111+
SSL: "ssl"
112+
tasks:
113+
- name: run-tests
114+
93115
- name: tests-8-noauth-nossl
94116
display_name: Run Tests 8.0 NoAuth NoSSL
95117
run_on: rhel87-small

django_mongodb_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .indexes import register_indexes # noqa: E402
1515
from .lookups import register_lookups # noqa: E402
1616
from .query import register_nodes # noqa: E402
17+
from .routers import register_routers # noqa: E402
1718

1819
__all__ = ["parse_uri"]
1920

@@ -25,3 +26,4 @@
2526
register_indexes()
2627
register_lookups()
2728
register_nodes()
29+
register_routers()

django_mongodb_backend/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,4 +229,7 @@ def cursor(self):
229229

230230
def get_database_version(self):
231231
"""Return a tuple of the database's version."""
232-
return tuple(self.connection.server_info()["versionArray"])
232+
# Avoid using PyMongo to check the database version or require
233+
# pymongocrypt>=1.14.2 which will contain a fix for the `buildInfo`
234+
# command. https://jira.mongodb.org/browse/PYTHON-5429
235+
return tuple(self.connection.admin.command("buildInfo")["versionArray"])

django_mongodb_backend/encryption.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Queryable Encryption helper classes and settings
2+
3+
import os
4+
5+
from .fields import has_encrypted_fields
6+
7+
KMS_CREDENTIALS = {
8+
"aws": {
9+
"key": os.getenv("AWS_KEY_ARN", ""),
10+
"region": os.getenv("AWS_KEY_REGION", ""),
11+
},
12+
"azure": {
13+
"keyName": os.getenv("AZURE_KEY_NAME", ""),
14+
"keyVaultEndpoint": os.getenv("AZURE_KEY_VAULT_ENDPOINT", ""),
15+
},
16+
"gcp": {
17+
"projectId": os.getenv("GCP_PROJECT_ID", ""),
18+
"location": os.getenv("GCP_LOCATION", ""),
19+
"keyRing": os.getenv("GCP_KEY_RING", ""),
20+
"keyName": os.getenv("GCP_KEY_NAME", ""),
21+
},
22+
"kmip": {},
23+
"local": {},
24+
}
25+
26+
KMS_PROVIDERS = {
27+
"aws": {},
28+
"azure": {},
29+
"gcp": {},
30+
"kmip": {
31+
"endpoint": os.getenv("KMIP_KMS_ENDPOINT", "not a valid endpoint"),
32+
},
33+
"local": {
34+
"key": bytes.fromhex(
35+
"000102030405060708090a0b0c0d0e0f"
36+
"101112131415161718191a1b1c1d1e1f"
37+
"202122232425262728292a2b2c2d2e2f"
38+
"303132333435363738393a3b3c3d3e3f"
39+
"404142434445464748494a4b4c4d4e4f"
40+
"505152535455565758595a5b5c5d5e5f"
41+
)
42+
},
43+
}
44+
45+
46+
class EncryptedRouter:
47+
def allow_migrate(self, db, app_label, model_name=None, model=None, **hints):
48+
if model:
49+
return db == ("other" if has_encrypted_fields(model) else "default")
50+
return db == "default"
51+
52+
def db_for_read(self, model, **hints):
53+
if has_encrypted_fields(model):
54+
return "other"
55+
return "default"
56+
57+
db_for_write = db_for_read
58+
59+
def kms_provider(self, model):
60+
return "local"
61+
62+
63+
class EqualityQuery(dict):
64+
def __init__(self, *, contention=None):
65+
super().__init__(queryType="equality")
66+
if contention is not None:
67+
self["contention"] = contention
68+
69+
70+
class RangeQuery(dict):
71+
def __init__(
72+
self, *, contention=None, max=None, min=None, precision=None, sparsity=None, trimFactor=None
73+
):
74+
super().__init__(queryType="range")
75+
options = {
76+
"contention": contention,
77+
"max": max,
78+
"min": min,
79+
"precision": precision,
80+
"sparsity": sparsity,
81+
"trimFactor": trimFactor,
82+
}
83+
self.update({k: v for k, v in options.items() if v is not None})

django_mongodb_backend/features.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,18 @@ class DatabaseFeatures(BaseDatabaseFeatures):
9797
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
9898
}
9999

100+
_django_test_expected_failures_queryable_encryption = {
101+
"encryption.tests.EncryptedFieldTests.test_get_encrypted_fields_map_method",
102+
"encryption.tests.EncryptedFieldTests.test_get_encrypted_fields_map_command",
103+
"encryption.tests.EncryptedFieldTests.test_set_encrypted_fields_map_in_client",
104+
"encryption.tests.EncryptedFieldTests.test_appointment",
105+
"encryption.tests.EncryptedFieldTests.test_billing",
106+
"encryption.tests.EncryptedFieldTests.test_patientportaluser",
107+
"encryption.tests.EncryptedFieldTests.test_patientrecord",
108+
"encryption.tests.EncryptedFieldTests.test_patient",
109+
"encryption.tests.EncryptedFieldTests.test_env",
110+
}
111+
100112
@cached_property
101113
def django_test_expected_failures(self):
102114
expected_failures = super().django_test_expected_failures
@@ -569,9 +581,17 @@ def django_test_expected_failures(self):
569581
},
570582
}
571583

584+
@cached_property
585+
def mongodb_version(self):
586+
return self.connection.get_database_version() # e.g., (6, 3, 0)
587+
572588
@cached_property
573589
def is_mongodb_6_3(self):
574-
return self.connection.get_database_version() >= (6, 3)
590+
return self.mongodb_version >= (6, 3)
591+
592+
@cached_property
593+
def is_mongodb_7_0(self):
594+
return self.mongodb_version >= (7, 0)
575595

576596
@cached_property
577597
def supports_atlas_search(self):
@@ -601,3 +621,20 @@ def _supports_transactions(self):
601621
hello = client.command("hello")
602622
# a replica set or a sharded cluster
603623
return "setName" in hello or hello.get("msg") == "isdbgrid"
624+
625+
@cached_property
626+
def supports_queryable_encryption(self):
627+
"""
628+
Queryable Encryption is supported if the server is Atlas or Enterprise
629+
and is configured as a replica set or a sharded cluster.
630+
"""
631+
self.connection.ensure_connection()
632+
client = self.connection.connection.admin
633+
build_info = client.command("buildInfo")
634+
is_enterprise = "enterprise" in build_info.get("modules")
635+
# Queryable Encryption requires transaction support which
636+
# is only available on replica sets or sharded clusters
637+
# which we already check in `supports_transactions`.
638+
supports_transactions = self.supports_transactions
639+
# TODO: check if the server is Atlas
640+
return is_enterprise and supports_transactions and self.is_mongodb_7_0

django_mongodb_backend/fields/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,50 @@
33
from .duration import register_duration_field
44
from .embedded_model import EmbeddedModelField
55
from .embedded_model_array import EmbeddedModelArrayField
6+
from .encryption import (
7+
EncryptedBigIntegerField,
8+
EncryptedBinaryField,
9+
EncryptedBooleanField,
10+
EncryptedCharField,
11+
EncryptedDateField,
12+
EncryptedDateTimeField,
13+
EncryptedDecimalField,
14+
EncryptedEmailField,
15+
EncryptedFieldMixin,
16+
EncryptedFloatField,
17+
EncryptedGenericIPAddressField,
18+
EncryptedIntegerField,
19+
EncryptedTextField,
20+
EncryptedTimeField,
21+
EncryptedURLField,
22+
has_encrypted_fields,
23+
)
624
from .json import register_json_field
725
from .objectid import ObjectIdField
826
from .polymorphic_embedded_model import PolymorphicEmbeddedModelField
927
from .polymorphic_embedded_model_array import PolymorphicEmbeddedModelArrayField
1028

1129
__all__ = [
30+
"has_encrypted_fields",
31+
"register_fields",
1232
"ArrayField",
1333
"EmbeddedModelArrayField",
1434
"EmbeddedModelField",
35+
"EncryptedBigIntegerField",
36+
"EncryptedBinaryField",
37+
"EncryptedBooleanField",
38+
"EncryptedCharField",
39+
"EncryptedDateTimeField",
40+
"EncryptedDateField",
41+
"EncryptedDecimalField",
42+
"EncryptedEmailField",
43+
"EncryptedFieldMixin",
44+
"EncryptedFloatField",
45+
"EncryptedGenericIPAddressField",
46+
"EncryptedIntegerField",
47+
"EncryptedTextField",
48+
"EncryptedTimeField",
49+
"EncryptedURLField",
1550
"ObjectIdAutoField",
1651
"ObjectIdField",
1752
"PolymorphicEmbeddedModelArrayField",
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from django.db import models
2+
3+
4+
def has_encrypted_fields(model):
5+
return any(getattr(field, "encrypted", False) for field in model._meta.fields)
6+
7+
8+
class EncryptedFieldMixin(models.Field):
9+
encrypted = True
10+
11+
def __init__(self, *args, queries=None, **kwargs):
12+
self.queries = queries
13+
super().__init__(*args, **kwargs)
14+
15+
def deconstruct(self):
16+
name, path, args, kwargs = super().deconstruct()
17+
18+
if self.queries is not None:
19+
kwargs["queries"] = self.queries
20+
21+
if path.startswith("django_mongodb_backend.fields.encrypted_model"):
22+
path = path.replace(
23+
"django_mongodb_backend.fields.encrypted_model",
24+
"django_mongodb_backend.fields",
25+
)
26+
27+
return name, path, args, kwargs
28+
29+
30+
class EncryptedBigIntegerField(EncryptedFieldMixin, models.BigIntegerField):
31+
pass
32+
33+
34+
class EncryptedBinaryField(EncryptedFieldMixin, models.BinaryField):
35+
pass
36+
37+
38+
class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField):
39+
pass
40+
41+
42+
class EncryptedCharField(EncryptedFieldMixin, models.CharField):
43+
pass
44+
45+
46+
class EncryptedDateField(EncryptedFieldMixin, models.DateField):
47+
pass
48+
49+
50+
class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
51+
pass
52+
53+
54+
class EncryptedDecimalField(EncryptedFieldMixin, models.DecimalField):
55+
pass
56+
57+
58+
class EncryptedEmailField(EncryptedFieldMixin, models.EmailField):
59+
pass
60+
61+
62+
class EncryptedFloatField(EncryptedFieldMixin, models.FloatField):
63+
pass
64+
65+
66+
class EncryptedGenericIPAddressField(EncryptedFieldMixin, models.GenericIPAddressField):
67+
pass
68+
69+
70+
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
71+
pass
72+
73+
74+
class EncryptedSlugField(EncryptedFieldMixin, models.SlugField):
75+
pass
76+
77+
78+
class EncryptedTimeField(EncryptedFieldMixin, models.TimeField):
79+
pass
80+
81+
82+
class EncryptedTextField(EncryptedFieldMixin, models.TextField):
83+
pass
84+
85+
86+
class EncryptedURLField(EncryptedFieldMixin, models.URLField):
87+
pass
88+
89+
90+
# TODO: Add more encrypted fields
91+
# - PositiveBigIntegerField
92+
# - PositiveIntegerField
93+
# - PositiveSmallIntegerField
94+
# - SmallIntegerField
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from bson import json_util
2+
from django.apps import apps
3+
from django.core.management.base import BaseCommand
4+
from django.db import DEFAULT_DB_ALIAS, connections, router
5+
from pymongo.encryption import ClientEncryption
6+
7+
8+
class Command(BaseCommand):
9+
help = "Generate a `schema_map` of encrypted fields for all encrypted"
10+
" models in the database for use with `AutoEncryptionOpts` in"
11+
" client configuration."
12+
13+
def add_arguments(self, parser):
14+
parser.add_argument(
15+
"--database",
16+
default=DEFAULT_DB_ALIAS,
17+
help="Specify the database to use for generating the encrypted"
18+
"fields map. Defaults to the 'default' database.",
19+
)
20+
parser.add_argument(
21+
"--kms-provider",
22+
default="local",
23+
help="Specify the KMS provider to use for encryption. Defaults to 'local'.",
24+
)
25+
26+
def handle(self, *args, **options):
27+
db = options["database"]
28+
kms_provider = options["kms_provider"]
29+
connection = connections[db]
30+
schema_map = json_util.dumps(
31+
self.get_encrypted_fields_map(connection, kms_provider), indent=2
32+
)
33+
self.stdout.write(schema_map)
34+
35+
def get_client_encryption(self, connection):
36+
client = connection.connection
37+
options = client._options.auto_encryption_opts
38+
key_vault_namespace = options._key_vault_namespace
39+
kms_providers = options._kms_providers
40+
return ClientEncryption(kms_providers, key_vault_namespace, client, client.codec_options)
41+
42+
def get_encrypted_fields_map(self, connection, kms_provider):
43+
schema_map = {}
44+
for app_config in apps.get_app_configs():
45+
for model in router.get_migratable_models(
46+
app_config, connection.settings_dict["NAME"], include_auto_created=False
47+
):
48+
if getattr(model, "encrypted", False):
49+
fields = connection.schema_editor()._get_encrypted_fields_map(model)
50+
ce = self.get_client_encryption(connection)
51+
master_key = connection.settings_dict.get("KMS_CREDENTIALS").get(kms_provider)
52+
for field in fields["fields"]:
53+
data_key = ce.create_data_key(
54+
kms_provider=kms_provider,
55+
master_key=master_key,
56+
)
57+
field["keyId"] = data_key
58+
schema_map[model._meta.db_table] = fields
59+
return schema_map

0 commit comments

Comments
 (0)