Skip to content

Commit 10ae20e

Browse files
committed
INTPYTHON-527 Add Queryable Encryption support
1 parent 4ca6c90 commit 10ae20e

29 files changed

+1205
-14
lines changed

.evergreen/config.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,28 @@ buildvariants:
9090
tasks:
9191
- name: run-tests
9292

93+
- name: tests-7-noauth-nossl
94+
display_name: Run Tests 7.0 NoAuth NoSSL
95+
run_on: rhel87-small
96+
expansions:
97+
MONGODB_VERSION: "7.0"
98+
TOPOLOGY: server
99+
AUTH: "noauth"
100+
SSL: "nossl"
101+
tasks:
102+
- name: run-tests
103+
104+
- name: tests-7-auth-ssl
105+
display_name: Run Tests 7.0 Auth SSL
106+
run_on: rhel87-small
107+
expansions:
108+
MONGODB_VERSION: "7.0"
109+
TOPOLOGY: server
110+
AUTH: "auth"
111+
SSL: "ssl"
112+
tasks:
113+
- name: run-tests
114+
93115
- name: tests-8-noauth-nossl
94116
display_name: Run Tests 8.0 NoAuth NoSSL
95117
run_on: rhel87-small

django_mongodb_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .indexes import register_indexes # noqa: E402
1515
from .lookups import register_lookups # noqa: E402
1616
from .query import register_nodes # noqa: E402
17+
from .routers import register_routers # noqa: E402
1718

1819
__all__ = ["parse_uri"]
1920

@@ -25,3 +26,4 @@
2526
register_indexes()
2627
register_lookups()
2728
register_nodes()
29+
register_routers()

django_mongodb_backend/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,4 +286,7 @@ def validate_no_broken_transaction(self):
286286

287287
def get_database_version(self):
288288
"""Return a tuple of the database's version."""
289-
return tuple(self.connection.server_info()["versionArray"])
289+
# Avoid using PyMongo to check the database version or require
290+
# pymongocrypt>=1.14.2 which will contain a fix for the `buildInfo`
291+
# command. https://jira.mongodb.org/browse/PYTHON-5429
292+
return tuple(self.connection.admin.command("buildInfo")["versionArray"])

django_mongodb_backend/encryption.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Queryable Encryption helper classes and settings
2+
3+
import os
4+
5+
KMS_CREDENTIALS = {
6+
"aws": {
7+
"key": os.getenv("AWS_KEY_ARN", ""),
8+
"region": os.getenv("AWS_KEY_REGION", ""),
9+
},
10+
"azure": {
11+
"keyName": os.getenv("AZURE_KEY_NAME", ""),
12+
"keyVaultEndpoint": os.getenv("AZURE_KEY_VAULT_ENDPOINT", ""),
13+
},
14+
"gcp": {
15+
"projectId": os.getenv("GCP_PROJECT_ID", ""),
16+
"location": os.getenv("GCP_LOCATION", ""),
17+
"keyRing": os.getenv("GCP_KEY_RING", ""),
18+
"keyName": os.getenv("GCP_KEY_NAME", ""),
19+
},
20+
"kmip": {},
21+
"local": {},
22+
}
23+
24+
KMS_PROVIDERS = {
25+
"aws": {},
26+
"azure": {},
27+
"gcp": {},
28+
"kmip": {
29+
"endpoint": os.getenv("KMIP_KMS_ENDPOINT", "not a valid endpoint"),
30+
},
31+
"local": {
32+
"key": bytes.fromhex(
33+
"000102030405060708090a0b0c0d0e0f"
34+
"101112131415161718191a1b1c1d1e1f"
35+
"202122232425262728292a2b2c2d2e2f"
36+
"303132333435363738393a3b3c3d3e3f"
37+
"404142434445464748494a4b4c4d4e4f"
38+
"505152535455565758595a5b5c5d5e5f"
39+
)
40+
},
41+
}
42+
43+
44+
class EncryptedRouter:
45+
def allow_migrate(self, db, app_label, model_name=None, model=None, **hints):
46+
if model:
47+
return db == ("other" if getattr(model, "encrypted", False) else "default")
48+
return db == "default"
49+
50+
def db_for_read(self, model, **hints):
51+
if getattr(model, "encrypted", False):
52+
return "other"
53+
return "default"
54+
55+
db_for_write = db_for_read
56+
57+
def kms_provider(self, model):
58+
return "local"
59+
60+
61+
class EqualityQuery(dict):
62+
def __init__(self, *, contention=None):
63+
super().__init__(queryType="equality")
64+
if contention is not None:
65+
self["contention"] = contention
66+
67+
68+
class RangeQuery(dict):
69+
def __init__(
70+
self, *, contention=None, max=None, min=None, precision=None, sparsity=None, trimFactor=None
71+
):
72+
super().__init__(queryType="range")
73+
options = {
74+
"contention": contention,
75+
"max": max,
76+
"min": min,
77+
"precision": precision,
78+
"sparsity": sparsity,
79+
"trimFactor": trimFactor,
80+
}
81+
self.update({k: v for k, v in options.items() if v is not None})

django_mongodb_backend/features.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
122122
"multiple_database.tests.QueryTestCase.test_generic_key_cross_database_protection",
123123
"multiple_database.tests.QueryTestCase.test_m2m_cross_database_protection",
124124
}
125+
_django_test_expected_failures_queryable_encryption = {
126+
"encryption.tests.EncryptedFieldTests.test_get_encrypted_fields_map_method",
127+
"encryption.tests.EncryptedFieldTests.test_get_encrypted_fields_map_command",
128+
"encryption.tests.EncryptedFieldTests.test_set_encrypted_fields_map_in_client",
129+
"encryption.tests.EncryptedFieldTests.test_appointment",
130+
"encryption.tests.EncryptedFieldTests.test_billing",
131+
"encryption.tests.EncryptedFieldTests.test_patientportaluser",
132+
"encryption.tests.EncryptedFieldTests.test_patientrecord",
133+
"encryption.tests.EncryptedFieldTests.test_patient",
134+
"encryption.tests.EncryptedFieldTests.test_env",
135+
}
125136

126137
@cached_property
127138
def django_test_expected_failures(self):
@@ -588,9 +599,17 @@ def django_test_expected_failures(self):
588599
},
589600
}
590601

602+
@cached_property
603+
def mongodb_version(self):
604+
return self.connection.get_database_version() # e.g., (6, 3, 0)
605+
591606
@cached_property
592607
def is_mongodb_6_3(self):
593-
return self.connection.get_database_version() >= (6, 3)
608+
return self.mongodb_version >= (6, 3)
609+
610+
@cached_property
611+
def is_mongodb_7_0(self):
612+
return self.mongodb_version >= (7, 0)
594613

595614
@cached_property
596615
def supports_atlas_search(self):
@@ -624,3 +643,20 @@ def supports_transactions(self):
624643
hello = client.command("hello")
625644
# a replica set or a sharded cluster
626645
return "setName" in hello or hello.get("msg") == "isdbgrid"
646+
647+
@cached_property
648+
def supports_queryable_encryption(self):
649+
"""
650+
Queryable Encryption is supported if the server is Atlas or Enterprise
651+
and is configured as a replica set or a sharded cluster.
652+
"""
653+
self.connection.ensure_connection()
654+
client = self.connection.connection.admin
655+
build_info = client.command("buildInfo")
656+
is_enterprise = "enterprise" in build_info.get("modules")
657+
# Queryable Encryption requires transaction support which
658+
# is only available on replica sets or sharded clusters
659+
# which we already check in `supports_transactions`.
660+
supports_transactions = self.supports_transactions
661+
# TODO: check if the server is Atlas
662+
return is_enterprise and supports_transactions and self.is_mongodb_7_0

django_mongodb_backend/fields/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,23 @@
33
from .duration import register_duration_field
44
from .embedded_model import EmbeddedModelField
55
from .embedded_model_array import EmbeddedModelArrayField
6+
from .encrypted_model import (
7+
EncryptedBigIntegerField,
8+
EncryptedBinaryField,
9+
EncryptedBooleanField,
10+
EncryptedCharField,
11+
EncryptedDateField,
12+
EncryptedDateTimeField,
13+
EncryptedDecimalField,
14+
EncryptedEmailField,
15+
EncryptedFieldMixin,
16+
EncryptedFloatField,
17+
EncryptedGenericIPAddressField,
18+
EncryptedIntegerField,
19+
EncryptedTextField,
20+
EncryptedTimeField,
21+
EncryptedURLField,
22+
)
623
from .json import register_json_field
724
from .objectid import ObjectIdField
825

@@ -11,6 +28,21 @@
1128
"ArrayField",
1229
"EmbeddedModelArrayField",
1330
"EmbeddedModelField",
31+
"EncryptedBigIntegerField",
32+
"EncryptedBinaryField",
33+
"EncryptedBooleanField",
34+
"EncryptedCharField",
35+
"EncryptedDateTimeField",
36+
"EncryptedDateField",
37+
"EncryptedDecimalField",
38+
"EncryptedEmailField",
39+
"EncryptedFieldMixin",
40+
"EncryptedFloatField",
41+
"EncryptedGenericIPAddressField",
42+
"EncryptedIntegerField",
43+
"EncryptedTextField",
44+
"EncryptedTimeField",
45+
"EncryptedURLField",
1446
"ObjectIdAutoField",
1547
"ObjectIdField",
1648
]
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from django.db import models
2+
3+
4+
class EncryptedFieldMixin(models.Field):
5+
encrypted = True
6+
7+
def __init__(self, *args, queries=None, **kwargs):
8+
self.queries = queries
9+
super().__init__(*args, **kwargs)
10+
11+
def deconstruct(self):
12+
name, path, args, kwargs = super().deconstruct()
13+
14+
if self.queries is not None:
15+
kwargs["queries"] = self.queries
16+
17+
if path.startswith("django_mongodb_backend.fields.encrypted_model"):
18+
path = path.replace(
19+
"django_mongodb_backend.fields.encrypted_model",
20+
"django_mongodb_backend.fields",
21+
)
22+
23+
return name, path, args, kwargs
24+
25+
26+
class EncryptedBigIntegerField(EncryptedFieldMixin, models.BigIntegerField):
27+
pass
28+
29+
30+
class EncryptedBinaryField(EncryptedFieldMixin, models.BinaryField):
31+
pass
32+
33+
34+
class EncryptedBooleanField(EncryptedFieldMixin, models.BooleanField):
35+
pass
36+
37+
38+
class EncryptedCharField(EncryptedFieldMixin, models.CharField):
39+
pass
40+
41+
42+
class EncryptedDateField(EncryptedFieldMixin, models.DateField):
43+
pass
44+
45+
46+
class EncryptedDateTimeField(EncryptedFieldMixin, models.DateTimeField):
47+
pass
48+
49+
50+
class EncryptedDecimalField(EncryptedFieldMixin, models.DecimalField):
51+
pass
52+
53+
54+
class EncryptedEmailField(EncryptedFieldMixin, models.EmailField):
55+
pass
56+
57+
58+
class EncryptedFloatField(EncryptedFieldMixin, models.FloatField):
59+
pass
60+
61+
62+
class EncryptedGenericIPAddressField(EncryptedFieldMixin, models.GenericIPAddressField):
63+
pass
64+
65+
66+
class EncryptedIntegerField(EncryptedFieldMixin, models.IntegerField):
67+
pass
68+
69+
70+
class EncryptedSlugField(EncryptedFieldMixin, models.SlugField):
71+
pass
72+
73+
74+
class EncryptedTimeField(EncryptedFieldMixin, models.TimeField):
75+
pass
76+
77+
78+
class EncryptedTextField(EncryptedFieldMixin, models.TextField):
79+
pass
80+
81+
82+
class EncryptedURLField(EncryptedFieldMixin, models.URLField):
83+
pass
84+
85+
86+
# TODO: Add more encrypted fields
87+
# - PositiveBigIntegerField
88+
# - PositiveIntegerField
89+
# - PositiveSmallIntegerField
90+
# - SmallIntegerField
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from bson import json_util
2+
from django.apps import apps
3+
from django.core.management.base import BaseCommand
4+
from django.db import DEFAULT_DB_ALIAS, connections, router
5+
from pymongo.encryption import ClientEncryption
6+
7+
8+
class Command(BaseCommand):
9+
help = "Generate a `schema_map` of encrypted fields for all encrypted"
10+
" models in the database for use with `AutoEncryptionOpts` in"
11+
" client configuration."
12+
13+
def add_arguments(self, parser):
14+
parser.add_argument(
15+
"--database",
16+
default=DEFAULT_DB_ALIAS,
17+
help="Specify the database to use for generating the encrypted"
18+
"fields map. Defaults to the 'default' database.",
19+
)
20+
parser.add_argument(
21+
"--kms-provider",
22+
default="local",
23+
help="Specify the KMS provider to use for encryption. Defaults to 'local'.",
24+
)
25+
26+
def handle(self, *args, **options):
27+
db = options["database"]
28+
kms_provider = options["kms_provider"]
29+
connection = connections[db]
30+
schema_map = json_util.dumps(
31+
self.get_encrypted_fields_map(connection, kms_provider), indent=2
32+
)
33+
self.stdout.write(schema_map)
34+
35+
def get_client_encryption(self, connection):
36+
client = connection.connection
37+
options = client._options.auto_encryption_opts
38+
key_vault_namespace = options._key_vault_namespace
39+
kms_providers = options._kms_providers
40+
return ClientEncryption(kms_providers, key_vault_namespace, client, client.codec_options)
41+
42+
def get_encrypted_fields_map(self, connection, kms_provider):
43+
schema_map = {}
44+
for app_config in apps.get_app_configs():
45+
for model in router.get_migratable_models(
46+
app_config, connection.settings_dict["NAME"], include_auto_created=False
47+
):
48+
if getattr(model, "encrypted", False):
49+
fields = connection.schema_editor()._get_encrypted_fields_map(model)
50+
ce = self.get_client_encryption(connection)
51+
master_key = connection.settings_dict.get("KMS_CREDENTIALS").get(kms_provider)
52+
for field in fields["fields"]:
53+
data_key = ce.create_data_key(
54+
kms_provider=kms_provider,
55+
master_key=master_key,
56+
)
57+
field["keyId"] = data_key
58+
schema_map[model._meta.db_table] = fields
59+
return schema_map

django_mongodb_backend/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,10 @@ def delete(self, *args, **kwargs):
1414

1515
def save(self, *args, **kwargs):
1616
raise NotSupportedError("EmbeddedModels cannot be saved.")
17+
18+
19+
class EncryptedModel(models.Model):
20+
encrypted = True
21+
22+
class Meta:
23+
abstract = True

0 commit comments

Comments
 (0)