Skip to content

Commit d0aa0d9

Browse files
[minor] Adding mqtt method to IOTX509RefreshableSession (#80)
* mqtt method * adding logging * Docs update * readme * Adding missing dependency. * readme changes update * final readme update
1 parent 254a241 commit d0aa0d9

File tree

5 files changed

+260
-35
lines changed

5 files changed

+260
-35
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
- **STS**
100100
- **IoT Core**
101101
- X.509 certificates w/ role aliases over mTLS (PEM files and PKCS#11)
102+
- MQTT actions are available!
102103
- Custom authentication methods
103104
- Natively supports all parameters supported by `boto3.session.Session`
104105
- [Tested](https://github.com/michaelthomasletts/boto3-refresh-session/tree/main/tests), [documented](https://michaelthomasletts.github.io/boto3-refresh-session/index.html), and [published to PyPI](https://pypi.org/project/boto3-refresh-session/)
@@ -317,6 +318,8 @@ pip install boto3-refresh-session
317318

318319
## ⚠️ Changes
319320

321+
Browse through the various changes to `boto3-refresh-session` over time.
322+
320323
#### 😥 v3.0.0
321324

322325
**The changes introduced by v3.0.0 will not impact ~99% of users** who generally interact with `boto3-refresh-session` by only `RefreshableSession`, *which is the intended usage for this package after all.*
@@ -331,4 +334,8 @@ The `ecs` module has been dropped. For additional details and rationale, please
331334

332335
#### 😛 v5.0.0
333336

334-
Support for IoT Core via X.509 certificate-based authentication (over HTTPS) is now available!
337+
Support for IoT Core via X.509 certificate-based authentication (over HTTPS) is now available!
338+
339+
#### ➕ v5.1.0
340+
341+
MQTT support added for IoT Core via X.509 certificate-based authentication.

boto3_refresh_session/methods/iot/x509.py

Lines changed: 232 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,42 @@
22

33
import json
44
import re
5+
from atexit import register
56
from pathlib import Path
6-
from typing import cast
7+
from tempfile import NamedTemporaryFile
8+
from typing import cast, get_args
79
from urllib.parse import ParseResult, urlparse
810

11+
from awscrt import auth, io
912
from awscrt.exceptions import AwsCrtError
1013
from awscrt.http import HttpClientConnection, HttpRequest
1114
from awscrt.io import (
1215
ClientBootstrap,
1316
ClientTlsContext,
1417
DefaultHostResolver,
1518
EventLoopGroup,
19+
LogLevel,
1620
Pkcs11Lib,
1721
TlsConnectionOptions,
1822
TlsContextOptions,
23+
init_logging,
1924
)
25+
from awscrt.mqtt import Connection
26+
from awsiot import mqtt_connection_builder
2027

2128
from ...exceptions import BRSError, BRSWarning
2229
from ...utils import (
2330
PKCS11,
2431
AWSCRTResponse,
2532
Identity,
2633
TemporaryCredentials,
34+
Transport,
2735
refreshable_session,
2836
)
2937
from .core import BaseIoTRefreshableSession
3038

39+
_TEMP_PATHS: list[str] = []
40+
3141

3242
@refreshable_session
3343
class IOTX509RefreshableSession(
@@ -71,6 +81,9 @@ class IOTX509RefreshableSession(
7181
The duration for which the temporary credentials are valid, in
7282
seconds. Cannot exceed the value declared in the IAM policy.
7383
Default is None.
84+
awscrt_log_level : awscrt.LogLevel | None, optional
85+
The logging level for the AWS CRT library, e.g.
86+
``awscrt.LogLevel.INFO``. Default is None.
7487
7588
Notes
7689
-----
@@ -90,34 +103,34 @@ def __init__(
90103
verify_peer: bool = True,
91104
timeout: float | int | None = None,
92105
duration_seconds: int | None = None,
106+
awscrt_log_level: LogLevel | None = None,
93107
**kwargs,
94108
):
95109
# initializing BRSSession
96110
super().__init__(refresh_method="iot-x509", **kwargs)
97111

112+
# logging
113+
if awscrt_log_level:
114+
init_logging(log_level=awscrt_log_level, file_name="stdout")
115+
98116
# initializing public attributes
99117
self.endpoint = self._normalize_iot_credential_endpoint(
100118
endpoint=endpoint
101119
)
102120
self.role_alias = role_alias
103-
self.certificate = certificate
121+
self.certificate = self._read_maybe_path_to_bytes(
122+
certificate, fallback=None, name="certificate"
123+
)
104124
self.thing_name = thing_name
105-
self.private_key = private_key
106-
self.pkcs11 = pkcs11
107-
self.ca = ca
125+
self.private_key = self._read_maybe_path_to_bytes(
126+
private_key, fallback=None, name="private_key"
127+
)
128+
self.pkcs11 = self._validate_pkcs11(pkcs11) if pkcs11 else None
129+
self.ca = self._read_maybe_path_to_bytes(ca, fallback=None, name="ca")
108130
self.verify_peer = verify_peer
109131
self.timeout = 10.0 if timeout is None else timeout
110132
self.duration_seconds = duration_seconds
111133

112-
# loading X.509 certificate if presented as a string, which
113-
# is presumed to be the file path.
114-
# if presented as bytes then self.certificate is presumed to be
115-
# the actual certificate itself
116-
if self.certificate and isinstance(self.certificate, str):
117-
self.certificate = (
118-
Path(self.certificate).expanduser().resolve().read_bytes()
119-
)
120-
121134
# either private_key or pkcs11 must be provided
122135
if self.private_key is None and self.pkcs11 is None:
123136
raise BRSError(
@@ -130,22 +143,6 @@ def __init__(
130143
"Only one of 'private_key' or 'pkcs11' can be provided."
131144
)
132145

133-
# if the provided private_key is bytes then it's presumed to be
134-
# the actual private key. but if it's string then it's presumed
135-
# to be the file path
136-
if self.private_key and isinstance(self.private_key, str):
137-
self.private_key = (
138-
Path(self.private_key).expanduser().resolve().read_bytes()
139-
)
140-
141-
# verifying PKCS#11 dict
142-
if self.pkcs11:
143-
self.pkcs11 = self._validate_pkcs11(pkcs11=self.pkcs11)
144-
145-
# ca is like many other attributes in that str implies file location
146-
if self.ca and isinstance(self.ca, str):
147-
self.ca = Path(self.ca).expanduser().resolve().read_bytes()
148-
149146
def _get_credentials(self) -> TemporaryCredentials:
150147
url = urlparse(
151148
f"https://{self.endpoint}/role-aliases/{self.role_alias}"
@@ -334,3 +331,208 @@ def _validate_pkcs11(pkcs11: PKCS11) -> PKCS11:
334331
pkcs11.setdefault("token_label", None)
335332
pkcs11.setdefault("private_key_label", None)
336333
return pkcs11
334+
335+
@staticmethod
336+
def _read_maybe_path_to_bytes(
337+
v: str | bytes | None, fallback: bytes | None, name: str
338+
) -> bytes | None:
339+
match v:
340+
case None:
341+
return fallback
342+
case bytes():
343+
return v
344+
case str() as p if Path(p).expanduser().resolve().is_file():
345+
return Path(p).expanduser().resolve().read_bytes()
346+
case _:
347+
raise BRSError(f"Invalid {name} provided.")
348+
349+
@staticmethod
350+
def _bytes_to_tempfile(b: bytes, suffix: str = ".pem") -> str:
351+
f = NamedTemporaryFile("wb", suffix=suffix, delete=False)
352+
f.write(b)
353+
f.flush()
354+
f.close()
355+
_TEMP_PATHS.append(f.name)
356+
return f.name
357+
358+
@staticmethod
359+
@register
360+
def _cleanup_tempfiles():
361+
for p in _TEMP_PATHS:
362+
try:
363+
Path(p).unlink(missing_ok=True)
364+
except Exception:
365+
...
366+
367+
def mqtt(
368+
self,
369+
*,
370+
endpoint: str,
371+
client_id: str,
372+
transport: Transport = "x509",
373+
certificate: str | bytes | None = None,
374+
private_key: str | bytes | None = None,
375+
ca: str | bytes | None = None,
376+
pkcs11: PKCS11 | None = None,
377+
region: str | None = None,
378+
keep_alive_secs: int = 60,
379+
clean_start: bool = True,
380+
port: int | None = None,
381+
use_alpn: bool = False,
382+
) -> Connection:
383+
"""Establishes an MQTT connection using the specified parameters.
384+
385+
Parameters
386+
----------
387+
endpoint: str
388+
The MQTT endpoint to connect to.
389+
client_id: str
390+
The client ID to use for the MQTT connection.
391+
transport: Transport
392+
The transport protocol to use (e.g., "x509" or "ws").
393+
certificate: str | bytes | None
394+
The client certificate to use for the connection. Defaults to the
395+
session certificate.
396+
private_key: str | bytes | None
397+
The private key to use for the connection. Defaults to the
398+
session private key.
399+
ca: str | bytes | None
400+
The CA certificate to use for the connection. Defaults to the
401+
session CA certificate.
402+
pkcs11: PKCS11 | None
403+
PKCS#11 configuration for hardware-backed keys. Defaults to the
404+
session PKCS#11 configuration.
405+
region: str | None
406+
The AWS region to use for the connection. Defaults to the
407+
session region.
408+
keep_alive_secs: int
409+
The keep-alive interval for the MQTT connection. Default is 60
410+
seconds.
411+
clean_start: bool
412+
Whether to start a clean session. Default is True.
413+
port: int | None
414+
The port to use for the MQTT connection. Default is 8883 if not
415+
using ALPN, otherwise 443.
416+
use_alpn: bool
417+
Whether to use ALPN for the connection. Default is False.
418+
419+
Returns
420+
-------
421+
awscrt.mqtt.Connection
422+
The established MQTT connection.
423+
"""
424+
425+
# Validate transport
426+
if transport not in list(get_args(Transport)):
427+
raise BRSError("Transport must be 'x509' or 'ws'")
428+
429+
# Region default (WS only)
430+
if region is None:
431+
region = self.region_name
432+
433+
# Normalize inputs to bytes using session defaults
434+
cert_bytes = self._read_maybe_path_to_bytes(
435+
certificate, getattr(self, "certificate", None), "certificate"
436+
)
437+
key_bytes = self._read_maybe_path_to_bytes(
438+
private_key, getattr(self, "private_key", None), "private_key"
439+
)
440+
ca_bytes = self._read_maybe_path_to_bytes(
441+
ca, getattr(self, "ca", None), "ca"
442+
)
443+
444+
# Validate PKCS#11
445+
match pkcs11:
446+
case None:
447+
pkcs11 = getattr(self, "pkcs11", None)
448+
case dict():
449+
pkcs11 = self._validate_pkcs11(pkcs11)
450+
case _:
451+
raise BRSError("Invalid PKCS#11 configuration provided.")
452+
453+
# X.509 invariants
454+
if transport == "x509":
455+
has_key = key_bytes is not None
456+
has_hsm = pkcs11 is not None
457+
if not has_key and not has_hsm:
458+
raise BRSError(
459+
"For transport='x509', provide either 'private_key' "
460+
"(bytes/path) or 'pkcs11'."
461+
)
462+
if has_key and has_hsm:
463+
raise BRSError(
464+
"Provide only one of 'private_key' or 'pkcs11' for "
465+
"transport='x509'."
466+
)
467+
if cert_bytes is None:
468+
raise BRSError("Certificate is required for transport='x509'")
469+
470+
# CRT bootstrap
471+
event_loop = io.EventLoopGroup(1)
472+
host_resolver = io.DefaultHostResolver(event_loop)
473+
bootstrap = io.ClientBootstrap(event_loop, host_resolver)
474+
475+
# Build connection
476+
if transport == "x509":
477+
if pkcs11 is not None:
478+
# Cert must be a filepath for PKCS#11 builder → write temp
479+
cert_path = self._bytes_to_tempfile(
480+
cast(bytes, cert_bytes), ".crt"
481+
)
482+
ca_path = (
483+
self._bytes_to_tempfile(ca_bytes, ".pem")
484+
if ca_bytes
485+
else None
486+
)
487+
488+
return mqtt_connection_builder.mtls_with_pkcs11(
489+
endpoint=endpoint,
490+
client_bootstrap=bootstrap,
491+
pkcs11_lib=Pkcs11Lib(file=pkcs11["pkcs11_lib"]),
492+
user_pin=pkcs11.get("user_pin"),
493+
slot_id=pkcs11.get("slot_id"),
494+
token_label=pkcs11.get("token_label"),
495+
private_key_object=pkcs11.get("private_key_label"),
496+
cert_filepath=cert_path,
497+
ca_filepath=ca_path,
498+
client_id=client_id,
499+
clean_session=clean_start,
500+
keep_alive_secs=keep_alive_secs,
501+
port=port or (443 if use_alpn else 8883),
502+
alpn_list=["x-amzn-mqtt-ca"] if use_alpn else None,
503+
)
504+
else:
505+
# pure mTLS with in-memory cert/key/CA
506+
return mqtt_connection_builder.mtls_from_bytes(
507+
endpoint=endpoint,
508+
cert_bytes=cert_bytes,
509+
pri_key_bytes=key_bytes,
510+
ca_bytes=ca_bytes,
511+
client_bootstrap=bootstrap,
512+
client_id=client_id,
513+
clean_session=clean_start,
514+
keep_alive_secs=keep_alive_secs,
515+
port=port or (443 if use_alpn else 8883),
516+
alpn_list=["x-amzn-mqtt-ca"] if use_alpn else None,
517+
)
518+
519+
else: # transport == "ws"
520+
# WebSockets + SigV4
521+
creds_provider = auth.AwsCredentialsProvider.new_delegate(
522+
self._credentials
523+
)
524+
ca_path = (
525+
self._bytes_to_tempfile(ca_bytes, ".pem") if ca_bytes else None
526+
)
527+
528+
return mqtt_connection_builder.websockets_with_default_aws_signing(
529+
endpoint=endpoint,
530+
client_bootstrap=bootstrap,
531+
region=region,
532+
credentials_provider=creds_provider,
533+
client_id=client_id,
534+
clean_session=clean_start,
535+
keep_alive_secs=keep_alive_secs,
536+
ca_filepath=ca_path,
537+
port=port or 443,
538+
)

boto3_refresh_session/utils/typing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"STSClientParams",
1414
"TemporaryCredentials",
1515
"RefreshableTemporaryCredentials",
16+
"Transport",
1617
]
1718

1819
from datetime import datetime
@@ -57,6 +58,9 @@
5758
#: Type alias for values returned by get_identity
5859
Identity: TypeAlias = dict[str, Any]
5960

61+
#: Type alias for acceptable transports
62+
Transport: TypeAlias = Literal["x509", "ws"]
63+
6064

6165
class TemporaryCredentials(TypedDict):
6266
"""Temporary IAM credentials."""

0 commit comments

Comments
 (0)