2
2
3
3
import json
4
4
import re
5
+ from atexit import register
5
6
from pathlib import Path
6
- from typing import cast
7
+ from tempfile import NamedTemporaryFile
8
+ from typing import cast , get_args
7
9
from urllib .parse import ParseResult , urlparse
8
10
11
+ from awscrt import auth , io
9
12
from awscrt .exceptions import AwsCrtError
10
13
from awscrt .http import HttpClientConnection , HttpRequest
11
14
from awscrt .io import (
12
15
ClientBootstrap ,
13
16
ClientTlsContext ,
14
17
DefaultHostResolver ,
15
18
EventLoopGroup ,
19
+ LogLevel ,
16
20
Pkcs11Lib ,
17
21
TlsConnectionOptions ,
18
22
TlsContextOptions ,
23
+ init_logging ,
19
24
)
25
+ from awscrt .mqtt import Connection
26
+ from awsiot import mqtt_connection_builder
20
27
21
28
from ...exceptions import BRSError , BRSWarning
22
29
from ...utils import (
23
30
PKCS11 ,
24
31
AWSCRTResponse ,
25
32
Identity ,
26
33
TemporaryCredentials ,
34
+ Transport ,
27
35
refreshable_session ,
28
36
)
29
37
from .core import BaseIoTRefreshableSession
30
38
39
+ _TEMP_PATHS : list [str ] = []
40
+
31
41
32
42
@refreshable_session
33
43
class IOTX509RefreshableSession (
@@ -71,6 +81,9 @@ class IOTX509RefreshableSession(
71
81
The duration for which the temporary credentials are valid, in
72
82
seconds. Cannot exceed the value declared in the IAM policy.
73
83
Default is None.
84
+ awscrt_log_level : awscrt.LogLevel | None, optional
85
+ The logging level for the AWS CRT library, e.g.
86
+ ``awscrt.LogLevel.INFO``. Default is None.
74
87
75
88
Notes
76
89
-----
@@ -90,34 +103,34 @@ def __init__(
90
103
verify_peer : bool = True ,
91
104
timeout : float | int | None = None ,
92
105
duration_seconds : int | None = None ,
106
+ awscrt_log_level : LogLevel | None = None ,
93
107
** kwargs ,
94
108
):
95
109
# initializing BRSSession
96
110
super ().__init__ (refresh_method = "iot-x509" , ** kwargs )
97
111
112
+ # logging
113
+ if awscrt_log_level :
114
+ init_logging (log_level = awscrt_log_level , file_name = "stdout" )
115
+
98
116
# initializing public attributes
99
117
self .endpoint = self ._normalize_iot_credential_endpoint (
100
118
endpoint = endpoint
101
119
)
102
120
self .role_alias = role_alias
103
- self .certificate = certificate
121
+ self .certificate = self ._read_maybe_path_to_bytes (
122
+ certificate , fallback = None , name = "certificate"
123
+ )
104
124
self .thing_name = thing_name
105
- self .private_key = private_key
106
- self .pkcs11 = pkcs11
107
- self .ca = ca
125
+ self .private_key = self ._read_maybe_path_to_bytes (
126
+ private_key , fallback = None , name = "private_key"
127
+ )
128
+ self .pkcs11 = self ._validate_pkcs11 (pkcs11 ) if pkcs11 else None
129
+ self .ca = self ._read_maybe_path_to_bytes (ca , fallback = None , name = "ca" )
108
130
self .verify_peer = verify_peer
109
131
self .timeout = 10.0 if timeout is None else timeout
110
132
self .duration_seconds = duration_seconds
111
133
112
- # loading X.509 certificate if presented as a string, which
113
- # is presumed to be the file path.
114
- # if presented as bytes then self.certificate is presumed to be
115
- # the actual certificate itself
116
- if self .certificate and isinstance (self .certificate , str ):
117
- self .certificate = (
118
- Path (self .certificate ).expanduser ().resolve ().read_bytes ()
119
- )
120
-
121
134
# either private_key or pkcs11 must be provided
122
135
if self .private_key is None and self .pkcs11 is None :
123
136
raise BRSError (
@@ -130,22 +143,6 @@ def __init__(
130
143
"Only one of 'private_key' or 'pkcs11' can be provided."
131
144
)
132
145
133
- # if the provided private_key is bytes then it's presumed to be
134
- # the actual private key. but if it's string then it's presumed
135
- # to be the file path
136
- if self .private_key and isinstance (self .private_key , str ):
137
- self .private_key = (
138
- Path (self .private_key ).expanduser ().resolve ().read_bytes ()
139
- )
140
-
141
- # verifying PKCS#11 dict
142
- if self .pkcs11 :
143
- self .pkcs11 = self ._validate_pkcs11 (pkcs11 = self .pkcs11 )
144
-
145
- # ca is like many other attributes in that str implies file location
146
- if self .ca and isinstance (self .ca , str ):
147
- self .ca = Path (self .ca ).expanduser ().resolve ().read_bytes ()
148
-
149
146
def _get_credentials (self ) -> TemporaryCredentials :
150
147
url = urlparse (
151
148
f"https://{ self .endpoint } /role-aliases/{ self .role_alias } "
@@ -334,3 +331,208 @@ def _validate_pkcs11(pkcs11: PKCS11) -> PKCS11:
334
331
pkcs11 .setdefault ("token_label" , None )
335
332
pkcs11 .setdefault ("private_key_label" , None )
336
333
return pkcs11
334
+
335
+ @staticmethod
336
+ def _read_maybe_path_to_bytes (
337
+ v : str | bytes | None , fallback : bytes | None , name : str
338
+ ) -> bytes | None :
339
+ match v :
340
+ case None :
341
+ return fallback
342
+ case bytes ():
343
+ return v
344
+ case str () as p if Path (p ).expanduser ().resolve ().is_file ():
345
+ return Path (p ).expanduser ().resolve ().read_bytes ()
346
+ case _:
347
+ raise BRSError (f"Invalid { name } provided." )
348
+
349
+ @staticmethod
350
+ def _bytes_to_tempfile (b : bytes , suffix : str = ".pem" ) -> str :
351
+ f = NamedTemporaryFile ("wb" , suffix = suffix , delete = False )
352
+ f .write (b )
353
+ f .flush ()
354
+ f .close ()
355
+ _TEMP_PATHS .append (f .name )
356
+ return f .name
357
+
358
+ @staticmethod
359
+ @register
360
+ def _cleanup_tempfiles ():
361
+ for p in _TEMP_PATHS :
362
+ try :
363
+ Path (p ).unlink (missing_ok = True )
364
+ except Exception :
365
+ ...
366
+
367
+ def mqtt (
368
+ self ,
369
+ * ,
370
+ endpoint : str ,
371
+ client_id : str ,
372
+ transport : Transport = "x509" ,
373
+ certificate : str | bytes | None = None ,
374
+ private_key : str | bytes | None = None ,
375
+ ca : str | bytes | None = None ,
376
+ pkcs11 : PKCS11 | None = None ,
377
+ region : str | None = None ,
378
+ keep_alive_secs : int = 60 ,
379
+ clean_start : bool = True ,
380
+ port : int | None = None ,
381
+ use_alpn : bool = False ,
382
+ ) -> Connection :
383
+ """Establishes an MQTT connection using the specified parameters.
384
+
385
+ Parameters
386
+ ----------
387
+ endpoint: str
388
+ The MQTT endpoint to connect to.
389
+ client_id: str
390
+ The client ID to use for the MQTT connection.
391
+ transport: Transport
392
+ The transport protocol to use (e.g., "x509" or "ws").
393
+ certificate: str | bytes | None
394
+ The client certificate to use for the connection. Defaults to the
395
+ session certificate.
396
+ private_key: str | bytes | None
397
+ The private key to use for the connection. Defaults to the
398
+ session private key.
399
+ ca: str | bytes | None
400
+ The CA certificate to use for the connection. Defaults to the
401
+ session CA certificate.
402
+ pkcs11: PKCS11 | None
403
+ PKCS#11 configuration for hardware-backed keys. Defaults to the
404
+ session PKCS#11 configuration.
405
+ region: str | None
406
+ The AWS region to use for the connection. Defaults to the
407
+ session region.
408
+ keep_alive_secs: int
409
+ The keep-alive interval for the MQTT connection. Default is 60
410
+ seconds.
411
+ clean_start: bool
412
+ Whether to start a clean session. Default is True.
413
+ port: int | None
414
+ The port to use for the MQTT connection. Default is 8883 if not
415
+ using ALPN, otherwise 443.
416
+ use_alpn: bool
417
+ Whether to use ALPN for the connection. Default is False.
418
+
419
+ Returns
420
+ -------
421
+ awscrt.mqtt.Connection
422
+ The established MQTT connection.
423
+ """
424
+
425
+ # Validate transport
426
+ if transport not in list (get_args (Transport )):
427
+ raise BRSError ("Transport must be 'x509' or 'ws'" )
428
+
429
+ # Region default (WS only)
430
+ if region is None :
431
+ region = self .region_name
432
+
433
+ # Normalize inputs to bytes using session defaults
434
+ cert_bytes = self ._read_maybe_path_to_bytes (
435
+ certificate , getattr (self , "certificate" , None ), "certificate"
436
+ )
437
+ key_bytes = self ._read_maybe_path_to_bytes (
438
+ private_key , getattr (self , "private_key" , None ), "private_key"
439
+ )
440
+ ca_bytes = self ._read_maybe_path_to_bytes (
441
+ ca , getattr (self , "ca" , None ), "ca"
442
+ )
443
+
444
+ # Validate PKCS#11
445
+ match pkcs11 :
446
+ case None :
447
+ pkcs11 = getattr (self , "pkcs11" , None )
448
+ case dict ():
449
+ pkcs11 = self ._validate_pkcs11 (pkcs11 )
450
+ case _:
451
+ raise BRSError ("Invalid PKCS#11 configuration provided." )
452
+
453
+ # X.509 invariants
454
+ if transport == "x509" :
455
+ has_key = key_bytes is not None
456
+ has_hsm = pkcs11 is not None
457
+ if not has_key and not has_hsm :
458
+ raise BRSError (
459
+ "For transport='x509', provide either 'private_key' "
460
+ "(bytes/path) or 'pkcs11'."
461
+ )
462
+ if has_key and has_hsm :
463
+ raise BRSError (
464
+ "Provide only one of 'private_key' or 'pkcs11' for "
465
+ "transport='x509'."
466
+ )
467
+ if cert_bytes is None :
468
+ raise BRSError ("Certificate is required for transport='x509'" )
469
+
470
+ # CRT bootstrap
471
+ event_loop = io .EventLoopGroup (1 )
472
+ host_resolver = io .DefaultHostResolver (event_loop )
473
+ bootstrap = io .ClientBootstrap (event_loop , host_resolver )
474
+
475
+ # Build connection
476
+ if transport == "x509" :
477
+ if pkcs11 is not None :
478
+ # Cert must be a filepath for PKCS#11 builder → write temp
479
+ cert_path = self ._bytes_to_tempfile (
480
+ cast (bytes , cert_bytes ), ".crt"
481
+ )
482
+ ca_path = (
483
+ self ._bytes_to_tempfile (ca_bytes , ".pem" )
484
+ if ca_bytes
485
+ else None
486
+ )
487
+
488
+ return mqtt_connection_builder .mtls_with_pkcs11 (
489
+ endpoint = endpoint ,
490
+ client_bootstrap = bootstrap ,
491
+ pkcs11_lib = Pkcs11Lib (file = pkcs11 ["pkcs11_lib" ]),
492
+ user_pin = pkcs11 .get ("user_pin" ),
493
+ slot_id = pkcs11 .get ("slot_id" ),
494
+ token_label = pkcs11 .get ("token_label" ),
495
+ private_key_object = pkcs11 .get ("private_key_label" ),
496
+ cert_filepath = cert_path ,
497
+ ca_filepath = ca_path ,
498
+ client_id = client_id ,
499
+ clean_session = clean_start ,
500
+ keep_alive_secs = keep_alive_secs ,
501
+ port = port or (443 if use_alpn else 8883 ),
502
+ alpn_list = ["x-amzn-mqtt-ca" ] if use_alpn else None ,
503
+ )
504
+ else :
505
+ # pure mTLS with in-memory cert/key/CA
506
+ return mqtt_connection_builder .mtls_from_bytes (
507
+ endpoint = endpoint ,
508
+ cert_bytes = cert_bytes ,
509
+ pri_key_bytes = key_bytes ,
510
+ ca_bytes = ca_bytes ,
511
+ client_bootstrap = bootstrap ,
512
+ client_id = client_id ,
513
+ clean_session = clean_start ,
514
+ keep_alive_secs = keep_alive_secs ,
515
+ port = port or (443 if use_alpn else 8883 ),
516
+ alpn_list = ["x-amzn-mqtt-ca" ] if use_alpn else None ,
517
+ )
518
+
519
+ else : # transport == "ws"
520
+ # WebSockets + SigV4
521
+ creds_provider = auth .AwsCredentialsProvider .new_delegate (
522
+ self ._credentials
523
+ )
524
+ ca_path = (
525
+ self ._bytes_to_tempfile (ca_bytes , ".pem" ) if ca_bytes else None
526
+ )
527
+
528
+ return mqtt_connection_builder .websockets_with_default_aws_signing (
529
+ endpoint = endpoint ,
530
+ client_bootstrap = bootstrap ,
531
+ region = region ,
532
+ credentials_provider = creds_provider ,
533
+ client_id = client_id ,
534
+ clean_session = clean_start ,
535
+ keep_alive_secs = keep_alive_secs ,
536
+ ca_filepath = ca_path ,
537
+ port = port or 443 ,
538
+ )
0 commit comments