Skip to content

Commit 13f3362

Browse files
committed
feat: add TLS support
1 parent a814004 commit 13f3362

File tree

1 file changed

+68
-43
lines changed

1 file changed

+68
-43
lines changed

src/arduino/app_peripherals/camera/websocket_camera.py

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from arduino.app_utils import Logger
2020

2121
from .base_camera import BaseCamera
22-
from .errors import CameraOpenError
22+
from .errors import CameraConfigError, CameraOpenError
2323

2424
logger = Logger("WebSocketCamera")
2525

@@ -41,16 +41,18 @@ class WebSocketCamera(BaseCamera):
4141
4242
Secure communication with the WebSocket server is supported in three security modes:
4343
- Security disabled (empty secret)
44-
- Authenticated (secret + enable_encryption=False) - HMAC-SHA256
45-
- Authenticated + Encrypted (secret + enable_encryption=True) - ChaCha20-Poly1305
44+
- Authenticated (secret + encrypt=False) - HMAC-SHA256
45+
- Authenticated + Encrypted (secret + encrypt=True) - ChaCha20-Poly1305
4646
"""
4747

4848
def __init__(
4949
self,
5050
port: int = 8080,
5151
timeout: int = 3,
52+
certs_dir_path: str = "/app/certs",
53+
use_tls: bool = False,
5254
secret: str = "",
53-
enable_encryption: bool = False,
55+
encrypt: bool = False,
5456
resolution: tuple[int, int] = (640, 480),
5557
fps: int = 10,
5658
adjustments: Callable[[np.ndarray], np.ndarray] | None = None,
@@ -62,26 +64,54 @@ def __init__(
6264
Args:
6365
port (int): Port to bind the server to
6466
timeout (int): Connection timeout in seconds
67+
certs_dir_path (str): Path to the directory containing TLS certificates
68+
use_tls (bool): Enable TLS for secure connections. If True, 'encrypt' will
69+
be ignored. Use this for transport-level security with clients that can
70+
accept self-signed certificates or when supplying your own certificates.
6571
secret (str): Secret key for authentication/encryption (empty = security disabled)
66-
enable_encryption (bool): Enable encryption (only effective if secret is provided)
72+
encrypt (bool): Enable encryption (only effective if secret is provided)
6773
resolution (tuple[int, int]): Resolution as (width, height)
6874
fps (int): Frames per second to capture
6975
adjustments (Callable[[np.ndarray], np.ndarray] | None): Function to adjust frames
7076
auto_reconnect (bool): Enable automatic reconnection on failure
7177
"""
7278
super().__init__(resolution, fps, adjustments, auto_reconnect)
7379

74-
self.protocol = "ws"
75-
self.port = port
76-
self.timeout = timeout
77-
self.codec = BPPCodec(secret, enable_encryption)
80+
if use_tls and encrypt:
81+
logger.warning("Encryption is redundant over TLS connections, disabling encryption.")
82+
encrypt = False
83+
84+
self.codec = BPPCodec(secret, encrypt)
7885
self.secret = secret
79-
self.enable_encryption = enable_encryption
86+
self.encrypt = encrypt
8087
self.logger = logger
88+
self.name = self.__class__.__name__
8189

82-
host_ip = os.getenv("HOST_IP")
90+
# Address and port configuration
91+
self.use_tls = use_tls
92+
self.protocol = "wss" if use_tls else "ws"
8393
self._bind_ip = "0.0.0.0"
94+
host_ip = os.getenv("HOST_IP")
8495
self.ip = host_ip if host_ip is not None else self._bind_ip
96+
if port < 0 or port > 65535:
97+
raise CameraConfigError(f"Invalid port number: {port}")
98+
self.port = port
99+
if timeout <= 0:
100+
raise CameraConfigError(f"Invalid timeout value: {timeout}")
101+
self.timeout = timeout
102+
103+
# TLS configuration
104+
if self.use_tls:
105+
import ssl
106+
from arduino.app_utils.tls_cert_manager import TLSCertificateManager
107+
108+
try:
109+
cert_path, key_path = TLSCertificateManager.get_or_create_certificates(certs_dir=certs_dir_path)
110+
self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
111+
self._ssl_context.load_cert_chain(cert_path, key_path)
112+
logger.info(f"SSL context created with certificate: {cert_path}")
113+
except Exception as e:
114+
raise RuntimeError("Failed to configure TLS certificate. Please check certificates and the certs directory.") from e
85115

86116
self._frame_queue = queue.Queue(1)
87117
self._server = None
@@ -101,7 +131,7 @@ def security_mode(self) -> str:
101131
"""Return current security mode for logging/debugging."""
102132
if not self.secret:
103133
return "none"
104-
elif self.enable_encryption:
134+
elif self.encrypt:
105135
return "encrypted (ChaCha20-Poly1305)"
106136
else:
107137
return "authenticated (HMAC-SHA256)"
@@ -136,12 +166,11 @@ def _start_server_thread(self) -> None:
136166
finally:
137167
if self._loop and not self._loop.is_closed():
138168
self._loop.close()
169+
self._loop = None
139170

140171
async def _start_server(self) -> None:
141172
"""Start the WebSocket server."""
142173
try:
143-
self._stop_event.clear()
144-
145174
self._server = await asyncio.wait_for(
146175
websockets.serve(
147176
self._ws_handler,
@@ -152,6 +181,7 @@ async def _start_server(self) -> None:
152181
close_timeout=self.timeout,
153182
ping_interval=20,
154183
max_size=5 * 1024 * 1024, # Limit max message size for security
184+
ssl=self._ssl_context if self.use_tls else None,
155185
),
156186
timeout=self.timeout,
157187
)
@@ -161,7 +191,7 @@ async def _start_server(self) -> None:
161191
server_socket = list(self._server.sockets)[0]
162192
self.port = server_socket.getsockname()[1]
163193

164-
await self._stop_event.wait()
194+
await self._server.wait_closed()
165195

166196
except TimeoutError as e:
167197
self.logger.error(f"Failed to start WebSocket server within {self.timeout}s: {e}")
@@ -170,10 +200,7 @@ async def _start_server(self) -> None:
170200
self.logger.error(f"Failed to start WebSocket server: {e}")
171201
raise
172202
finally:
173-
if self._server:
174-
self._server.close()
175-
await self._server.wait_closed()
176-
self._server = None
203+
self._server = None
177204

178205
async def _ws_handler(self, conn: websockets.ServerConnection) -> None:
179206
"""Handle a connected WebSocket client. Only one client allowed at a time."""
@@ -200,14 +227,14 @@ async def _ws_handler(self, conn: websockets.ServerConnection) -> None:
200227
try:
201228
# Send welcome message
202229
try:
203-
welcome = json.dumps({
230+
welcome = {
204231
"status": "connected",
205232
"message": "Connected to camera server",
206233
"security_mode": self.security_mode,
207234
"resolution": self.resolution,
208235
"fps": self.fps,
209-
})
210-
await self._send_to_client(welcome)
236+
}
237+
await self._send_to_client(json.dumps(welcome))
211238
except Exception as e:
212239
self.logger.warning(f"Failed to send welcome message: {e}")
213240

@@ -241,6 +268,7 @@ async def _ws_handler(self, conn: websockets.ServerConnection) -> None:
241268
self.logger.debug(f"Client removed: {client_addr}")
242269

243270
def _parse_message(self, message: websockets.Data) -> np.ndarray | None:
271+
"""Parse WebSocket message to extract a video frame."""
244272
if isinstance(message, str):
245273
try:
246274
message = base64.b64decode(message)
@@ -261,7 +289,7 @@ def _close_camera(self):
261289
"""Stop the WebSocket server."""
262290
if self._loop and not self._loop.is_closed() and self._loop.is_running():
263291
try:
264-
future = asyncio.run_coroutine_threadsafe(self._stop_and_disconnect_client(), self._loop)
292+
future = asyncio.run_coroutine_threadsafe(self._disconnect_and_stop(), self._loop)
265293
future.result(1.0)
266294
except CancelledError:
267295
self.logger.debug(f"Error stopping WebSocket server: CancelledError")
@@ -281,26 +309,23 @@ def _close_camera(self):
281309
except queue.Empty:
282310
pass
283311

284-
# Reset state
285-
self._server = None
286-
self._loop = None
287-
self._client = None
288-
289-
async def _stop_and_disconnect_client(self):
290-
"""Cleanly disconnect client with goodbye message."""
291-
if self._client:
292-
try:
293-
self.logger.debug("Disconnecting client...")
294-
goodbye = json.dumps({"status": "disconnecting", "message": "Server is shutting down"})
295-
await self._send_to_client(goodbye)
296-
except Exception as e:
297-
self.logger.warning(f"Failed to send goodbye message: {e}")
298-
finally:
299-
if self._client:
300-
await self._client.close()
301-
self.logger.debug("Client connection closed")
302-
303-
self._stop_event.set()
312+
async def _disconnect_and_stop(self):
313+
"""Cleanly disconnect client with goodbye message and stop the server."""
314+
async with self._client_lock:
315+
if self._client:
316+
try:
317+
self.logger.debug("Disconnecting client...")
318+
goodbye = json.dumps({"status": "disconnecting", "message": "Server is shutting down"})
319+
await self._send_to_client(goodbye)
320+
except Exception as e:
321+
self.logger.warning(f"Failed to send goodbye message: {e}")
322+
finally:
323+
if self._client:
324+
await self._client.close()
325+
self.logger.debug("Client connection closed")
326+
327+
if self._server:
328+
self._server.close()
304329

305330
def _read_frame(self) -> np.ndarray | None:
306331
"""Read a single frame from the queue."""

0 commit comments

Comments
 (0)