Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 44 additions & 16 deletions src/endpoints/consumer/consumer_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
from typing import Annotated, Never

import numpy as np
from fastapi import APIRouter, Header, HTTPException
from fastapi import APIRouter, Header, HTTPException, Request
from numpy import ndarray
from pydantic import TypeAdapter, ValidationError

from src.endpoints.consumer import (
InferencePartialPayload,
KServeData,
KServeInferenceRequest,
KServeInferenceResponse,
)
from src.endpoints.consumer.gzip_utils import decompress_if_gzip
from src.exceptions import ReconciliationError
from src.service.data.datasources.data_source import DataSource

Expand Down Expand Up @@ -440,23 +442,24 @@ def process_payload(
return np.array(kserve_data.data), column_names


@router.post("/")
async def consume_cloud_event(
_kserve_payload_adapter = TypeAdapter(KServeInferenceRequest | KServeInferenceResponse)


async def process_cloud_event(
payload: KServeInferenceRequest | KServeInferenceResponse,
ce_id: Annotated[str | None, Header()] = None,
ce_id: str | None = None,
tag: str | None = None,
) -> dict[str, str]:
"""Consume KServe v2 payloads from cloud events.
"""Process a KServe payload from a cloud event or internal call.

This endpoint accepts both input (request) and output (response) payloads
from ModelMesh-served models and stores them for reconciliation.
This is the core logic shared by the HTTP endpoint and the upload
endpoint's internal forwarding path.

:param payload: KServe inference request or response
:param ce_id: Cloud event ID from header
:param payload: Parsed KServe inference request or response
:param ce_id: Cloud event ID from header (overrides payload.id)
:param tag: Optional tag to associate with the data
:raises HTTPException: If payload processing fails
"""
# set payload id from cloud event header if present
if ce_id is not None:
payload.id = ce_id

Expand All @@ -466,7 +469,6 @@ async def consume_cloud_event(
detail="Payload requires 'id' field or 'ce-id' header",
)

# get global storage interface
storage_interface = get_global_storage_interface()

try:
Expand All @@ -478,13 +480,11 @@ async def consume_cloud_event(
)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=msg)
logger.info("KServe Inference Input %s received.", payload.id)
# if a match is found, the payload is auto-deleted from data
partial_output = await storage_interface.get_partial_payload(
payload.id, is_input=False, is_modelmesh=False
)
if partial_output is not None:
if not isinstance(partial_output, KServeInferenceResponse):
# This should never happen - indicates storage interface error
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Invalid payload type from storage",
Expand Down Expand Up @@ -516,7 +516,6 @@ async def consume_cloud_event(
)
if partial_input is not None:
if not isinstance(partial_input, KServeInferenceRequest):
# This should never happen - indicates storage interface error
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Invalid payload type from storage",
Expand All @@ -532,8 +531,6 @@ async def consume_cloud_event(
"message": f"Output payload {payload.id} processed successfully",
}

# Defensive programming: this should never happen due to type annotation
# but adding explicit fallback for type safety
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Payload must be either KServeInferenceRequest or KServeInferenceResponse",
Expand All @@ -542,3 +539,34 @@ async def consume_cloud_event(
except ReconciliationError as e:
logger.exception("Reconciliation failed for payload %s", payload.id)
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) from e


@router.post("/")
async def consume_cloud_event(
http_request: Request,
ce_id: Annotated[str | None, Header()] = None,
tag: str | None = None,
) -> dict[str, str]:
"""Consume KServe v2 payloads from cloud events.

Knative Eventing may strip the Content-Encoding header while leaving the
body gzip-compressed, so this endpoint detects gzip by magic bytes and
decompresses before JSON parsing.

:param http_request: Raw HTTP request (body may be gzip-compressed without header)
:param ce_id: Cloud event ID from header
:param tag: Optional tag to associate with the data
:raises HTTPException: If payload processing fails
"""
raw_body = await http_request.body()
body = decompress_if_gzip(raw_body)

try:
payload = _kserve_payload_adapter.validate_json(body)
except ValidationError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Invalid payload: {e}",
) from e

return await process_cloud_event(payload, ce_id=ce_id, tag=tag)
59 changes: 59 additions & 0 deletions src/endpoints/consumer/gzip_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Gzip magic-byte detection for CloudEvent payloads.

Knative Eventing reconstructs HTTP requests via the CloudEvents SDK,
which drops transport headers like Content-Encoding while leaving the
body gzip-compressed. This module detects gzip by magic bytes (0x1F 0x8B)
and decompresses at the application layer — the same fix applied to the
Java service's CloudEventConsumer.decompressIfGzip().
"""

import gzip
import logging
from io import BytesIO

from src.middleware.gzip_middleware import GzipRequestMiddleware

logger = logging.getLogger(__name__)

_GZIP_MAGIC = b"\x1f\x8b"
_GZIP_MAGIC_LEN = len(_GZIP_MAGIC)
_CHUNK_SIZE = 64 * 1024 # 64KB streaming chunks
DEFAULT_MAX_DECOMPRESSED_SIZE = GzipRequestMiddleware.DEFAULT_MAX_SIZE


def decompress_if_gzip(
data: bytes,
max_size: int = DEFAULT_MAX_DECOMPRESSED_SIZE,
) -> bytes:
"""Decompress data if it starts with gzip magic bytes.

Returns the original data unchanged if it is not gzip-compressed
or if decompression fails.

:param data: Raw bytes to check and potentially decompress
:param max_size: Maximum allowed decompressed size in bytes
:return: Decompressed bytes, or original data if not gzip
"""
if len(data) < _GZIP_MAGIC_LEN or data[:_GZIP_MAGIC_LEN] != _GZIP_MAGIC:
return data

try:
decompressed = bytearray()
with BytesIO(data) as bio, gzip.GzipFile(fileobj=bio) as gz:
while True:
chunk = gz.read(_CHUNK_SIZE)
if not chunk:
break
if len(decompressed) + len(chunk) > max_size:
msg = f"Decompressed CloudEvent payload exceeds {max_size} bytes"
raise ValueError(msg)
decompressed.extend(chunk)

logger.debug("Decompressed gzip CloudEvent payload")
return bytes(decompressed)

except (gzip.BadGzipFile, OSError):
logger.warning(
"CloudEvent payload starts with gzip magic bytes but failed to decompress, using raw bytes",
)
return data
6 changes: 3 additions & 3 deletions src/endpoints/data/data_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel

from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse
from src.endpoints.consumer.consumer_endpoint import consume_cloud_event
from src.endpoints.consumer.consumer_endpoint import process_cloud_event
from src.exceptions import ReconciliationError
from src.service.constants import TRUSTYAI_TAG_PREFIX
from src.service.data.model_data import ModelData
Expand Down Expand Up @@ -81,8 +81,8 @@ async def upload(payload: UploadPayload) -> dict[str, str]:
else:
previous_data_points = 0

await consume_cloud_event(payload.response, req_id)
await consume_cloud_event(payload.request, req_id, tag=payload.data_tag)
await process_cloud_event(payload.response, req_id)
await process_cloud_event(payload.request, req_id, tag=payload.data_tag)

model_data = ModelData(payload.model_name)
new_data_points = (await model_data.row_counts())[0]
Expand Down
6 changes: 3 additions & 3 deletions src/middleware/gzip_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ class GzipRequestMiddleware:
removing the Content-Encoding header, and updating Content-Length.
Includes protection against decompression bombs via max_size limit.

Defaults: paths=["/data/upload"], max_size=16MB, fail_on_error=True
Defaults: paths=["*"] (all paths), max_size=16MB, fail_on_error=True
"""

# Default configuration constants
DEFAULT_PATHS = ("/data/upload",) # Tuple to avoid mutable default
DEFAULT_PATHS = ("*",) # All paths: Content-Encoding is a transport-level concern
DEFAULT_ALLOWED_CONTENT_TYPES = (
"application/json",
"application/cloudevents+json",
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(

Args:
app: ASGI application
paths: Path patterns to apply (supports wildcards, default: ["/data/upload"])
paths: Path patterns to apply (supports wildcards, default: ["*"] = all paths)
max_size: Max decompressed bytes (default: 16MB)
fail_on_error: Return error on failure vs pass through (default: True)
allowed_content_types: Eligible content types
Expand Down
92 changes: 92 additions & 0 deletions tests/endpoints/consumer/test_cloud_event_gzip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Tests for Layer 2 gzip decompression in the CloudEvent consumer endpoint.

Knative Eventing strips Content-Encoding headers while leaving the body
gzip-compressed. These tests verify that the CloudEvent endpoint detects
gzip by magic bytes and decompresses before JSON parsing.
"""

import gzip
import json

import pytest

from src.endpoints.consumer.gzip_utils import decompress_if_gzip


class TestDecompressIfGzip:
"""Unit tests for the decompress_if_gzip utility."""

def test_decompresses_gzip_data(self) -> None:
"""Valid gzip data is decompressed."""
original = b'{"inputs": [{"name": "x", "shape": [1], "datatype": "FP32", "data": [1.0]}]}'
compressed = gzip.compress(original)

result = decompress_if_gzip(compressed)

assert result == original

def test_returns_non_gzip_unchanged(self) -> None:
"""Non-gzip data is returned unchanged."""
data = b'{"inputs": [{"name": "x", "shape": [1], "datatype": "FP32", "data": [1.0]}]}'

result = decompress_if_gzip(data)

assert result is data

def test_returns_empty_bytes_unchanged(self) -> None:
"""Empty bytes are returned unchanged."""
result = decompress_if_gzip(b"")

assert result == b""

def test_returns_single_byte_unchanged(self) -> None:
"""Single byte (too short for magic check) is returned unchanged."""
result = decompress_if_gzip(b"\x1f")

assert result == b"\x1f"

def test_invalid_gzip_with_magic_bytes_returns_original(self) -> None:
"""Data starting with gzip magic but not valid gzip returns original."""
fake_gzip = b"\x1f\x8b\x00\x00invalid"

result = decompress_if_gzip(fake_gzip)

assert result == fake_gzip

def test_size_limit_raises_on_decompression_bomb(self) -> None:
"""Exceeding max_size raises ValueError."""
large_data = b"x" * 10_000
compressed = gzip.compress(large_data)

with pytest.raises(ValueError, match="exceeds"):
decompress_if_gzip(compressed, max_size=100)

def test_size_limit_within_bounds_succeeds(self) -> None:
"""Data within max_size decompresses successfully."""
data = b'{"test": true}'
compressed = gzip.compress(data)

result = decompress_if_gzip(compressed, max_size=1024)

assert result == data

def test_preserves_json_fidelity(self) -> None:
"""Decompressed JSON round-trips correctly."""
payload = {
"model_name": "example",
"id": "req-001",
"outputs": [
{
"name": "predict",
"shape": [2, 1],
"datatype": "FP64",
"data": [[0.1], [0.9]],
},
],
}
original = json.dumps(payload).encode()
compressed = gzip.compress(original)

result = decompress_if_gzip(compressed)

assert json.loads(result) == payload
Loading
Loading