Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion openapi_core/contrib/starlette/middlewares.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""OpenAPI core contrib starlette middlewares module"""

from typing import Type

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
Expand All @@ -14,14 +16,26 @@
StarletteOpenAPIValidRequestHandler,
)
from openapi_core.contrib.starlette.integrations import StarletteIntegration
from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest
from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse


class StarletteOpenAPIMiddleware(StarletteIntegration, BaseHTTPMiddleware):
valid_request_handler_cls = StarletteOpenAPIValidRequestHandler
errors_handler = StarletteOpenAPIErrorsHandler()

def __init__(self, app: ASGIApp, openapi: OpenAPI):
def __init__(
self,
app: ASGIApp,
openapi: OpenAPI,
request_cls: Type[StarletteOpenAPIRequest] = StarletteOpenAPIRequest,
response_cls: Type[
StarletteOpenAPIResponse
] = StarletteOpenAPIResponse,
):
super().__init__(openapi)
self.request_cls = request_cls
self.response_cls = response_cls
BaseHTTPMiddleware.__init__(self, app)

async def dispatch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from starletteproject.pets.endpoints import pet_detail_endpoint
from starletteproject.pets.endpoints import pet_list_endpoint
from starletteproject.pets.endpoints import pet_photo_endpoint
from starletteproject.tags.endpoints import tag_list_endpoint

from openapi_core.contrib.starlette.middlewares import (
StarletteOpenAPIMiddleware,
Expand All @@ -16,17 +17,30 @@
openapi=openapi,
),
]
middleware_skip_response = [
Middleware(
StarletteOpenAPIMiddleware,
openapi=openapi,
response_cls=None,
),
]

routes = [
Route("/v1/pets", pet_list_endpoint, methods=["GET", "POST"]),
Route("/v1/pets/{petId}", pet_detail_endpoint, methods=["GET", "POST"]),
Route(
"/v1/pets/{petId}/photo", pet_photo_endpoint, methods=["GET", "POST"]
),
Route("/v1/tags", tag_list_endpoint, methods=["GET"]),
]

app = Starlette(
debug=True,
middleware=middleware,
routes=routes,
)
app_skip_response = Starlette(
debug=True,
middleware=middleware_skip_response,
routes=routes,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from starlette.responses import Response


async def tag_list_endpoint(request):
assert request.scope["openapi"]
assert not request.scope["openapi"].errors
assert request.method == "GET"
headers = {
"X-Rate-Limit": "12",
}
return Response(status_code=201, headers=headers)
73 changes: 62 additions & 11 deletions tests/integration/contrib/starlette/test_starlette_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,18 @@ def project_setup():
sys.path.remove(project_dir)


@pytest.fixture
def app():
from starletteproject.__main__ import app

return app

class BaseTestPetstore:
api_key = "12345"

@pytest.fixture
def client(app):
return TestClient(app, base_url="http://petstore.swagger.io")
@pytest.fixture
def app(self):
from starletteproject.__main__ import app

return app

class BaseTestPetstore:
api_key = "12345"
@pytest.fixture
def client(self, app):
return TestClient(app, base_url="http://petstore.swagger.io")

@property
def api_key_encoded(self):
Expand All @@ -37,6 +35,19 @@ def api_key_encoded(self):
return str(api_key_bytes_enc, "utf8")


class BaseTestPetstoreSkipReponse:

@pytest.fixture
def app(self):
from starletteproject.__main__ import app_skip_response

return app_skip_response

@pytest.fixture
def client(self, app):
return TestClient(app, base_url="http://petstore.swagger.io")


class TestPetListEndpoint(BaseTestPetstore):
def test_get_no_required_param(self, client):
headers = {
Expand Down Expand Up @@ -381,3 +392,43 @@ def test_post_valid(self, client, data_gif):

assert not response.text
assert response.status_code == 201


class TestTagListEndpoint(BaseTestPetstore):

def test_get_invalid(self, client):
headers = {
"Authorization": "Basic testuser",
}

response = client.get(
"/v1/tags",
headers=headers,
)

assert response.status_code == 400
assert response.json() == {
"errors": [
{
"title": "Missing response data",
"status": 400,
"type": "<class 'openapi_core.validation.response.exceptions.MissingData'>",
},
],
}


class TestSkipResponseTagListEndpoint(BaseTestPetstoreSkipReponse):

def test_get_valid(self, client):
headers = {
"Authorization": "Basic testuser",
}

response = client.get(
"/v1/tags",
headers=headers,
)

assert not response.text
assert response.status_code == 201
Loading