From 35a8b4315bfad78df4f1707be1bbf04da1574f13 Mon Sep 17 00:00:00 2001 From: Sudip Sinha Date: Tue, 16 Jun 2026 08:39:59 +0100 Subject: [PATCH] fix(drift): derive fitColumns from stored metadata when omitted The Python service rejected drift metric requests with HTTP 400 when fitColumns was not provided. The Java service treats fitColumns as optional, fitting on all input columns when omitted. Derive fitColumns from the model's stored input schema metadata across all 6 drift endpoint functions (3 compute + 3 schedule). Co-Authored-By: Claude Opus 4.6 Signed-off-by: Sudip Sinha --- src/endpoints/metrics/drift/compare_means.py | 54 +++++++++++-------- src/endpoints/metrics/drift/jensen_shannon.py | 20 ++++--- .../metrics/drift/kolmogorov_smirnov.py | 20 +++++-- tests/endpoints/metrics/drift/factory.py | 19 +++---- .../metrics/drift/test_compare_means.py | 24 ++++----- .../metrics/drift/test_jensen_shannon.py | 24 ++++----- .../metrics/drift/test_kolmogorov_smirnov.py | 24 ++++----- 7 files changed, 108 insertions(+), 77 deletions(-) diff --git a/src/endpoints/metrics/drift/compare_means.py b/src/endpoints/metrics/drift/compare_means.py index 2b9a7e9a..6f0df177 100644 --- a/src/endpoints/metrics/drift/compare_means.py +++ b/src/endpoints/metrics/drift/compare_means.py @@ -103,18 +103,22 @@ async def compute_compare_means( ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", - ) - - # Validate feature names are not blank/whitespace - valid_features = [f.strip() for f in request.fit_columns if f.strip()] - if not valid_features: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns must contain at least one non-empty feature name", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) + else: + valid_features = [f.strip() for f in request.fit_columns if f.strip()] + if not valid_features: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="fitColumns must contain at least one non-empty feature name", + ) + request.fit_columns = valid_features try: logger.info("Computing %s for model: %s", METRIC_NAME, request.model_id) @@ -153,7 +157,7 @@ async def compute_compare_means( # Multi-feature case: iterate over features results = {} - for feature_name in valid_features: + for feature_name in request.fit_columns: if ( feature_name not in reference_df.columns or feature_name not in current_df.columns @@ -254,18 +258,22 @@ async def schedule_compare_means(request: CompareMeansMetricRequest) -> dict[str ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", - ) - - # Validate feature names are not blank/whitespace - valid_features = [f.strip() for f in request.fit_columns if f.strip()] - if not valid_features: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns must contain at least one non-empty feature name", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) + else: + valid_features = [f.strip() for f in request.fit_columns if f.strip()] + if not valid_features: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="fitColumns must contain at least one non-empty feature name", + ) + request.fit_columns = valid_features try: # Generate UUID for this request diff --git a/src/endpoints/metrics/drift/jensen_shannon.py b/src/endpoints/metrics/drift/jensen_shannon.py index 36cef0f2..7977fe83 100644 --- a/src/endpoints/metrics/drift/jensen_shannon.py +++ b/src/endpoints/metrics/drift/jensen_shannon.py @@ -106,9 +106,13 @@ async def compute_jensenshannon( ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) try: @@ -231,9 +235,13 @@ async def schedule_jensenshannon(request: JensenShannonMetricRequest) -> dict[st ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) # Get the scheduler and validate availability diff --git a/src/endpoints/metrics/drift/kolmogorov_smirnov.py b/src/endpoints/metrics/drift/kolmogorov_smirnov.py index 7875fa02..bda435d3 100644 --- a/src/endpoints/metrics/drift/kolmogorov_smirnov.py +++ b/src/endpoints/metrics/drift/kolmogorov_smirnov.py @@ -83,9 +83,13 @@ async def compute_kstest( ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) try: @@ -186,6 +190,16 @@ async def get_kstest_definition() -> dict[str, str]: @router.post("/metrics/drift/kstest/request") async def schedule_kstest(request: KSTestMetricRequest) -> dict[str, str]: """Schedule a recurring computation of KSTest metric.""" + if not request.fit_columns: + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, + ) + # Get the scheduler and validate availability scheduler = get_prometheus_scheduler() if not scheduler: diff --git a/tests/endpoints/metrics/drift/factory.py b/tests/endpoints/metrics/drift/factory.py index bd680689..95c37c9d 100644 --- a/tests/endpoints/metrics/drift/factory.py +++ b/tests/endpoints/metrics/drift/factory.py @@ -73,16 +73,17 @@ def make_compute_endpoint_test( def test_impl(_: object, mock_ds: MagicMock) -> None: """Test compute endpoint returns valid response structure.""" # Create sample dataframe (Pandas or Polars based on df_type) - sample_df = _create_sample_dataframe( - request_payload.get("fitColumns", ["feature1"]), - df_type=df_type, - ) + columns = request_payload.get("fitColumns", ["feature1"]) + sample_df = _create_sample_dataframe(columns, df_type=df_type) # Mock data source mock_data_source = MagicMock() mock_data_source.get_dataframe_by_tag = AsyncMock(return_value=sample_df) mock_data_source.get_organic_dataframe = AsyncMock(return_value=sample_df) mock_data_source.get_dataframe = AsyncMock(return_value=sample_df) + mock_metadata = MagicMock() + mock_metadata.input_schema.items.keys.return_value = columns + mock_data_source.get_metadata = AsyncMock(return_value=mock_metadata) mock_ds.return_value = mock_data_source # Send request @@ -338,17 +339,17 @@ def make_compute_endpoint_error_test( def test_impl(_: object, mock_ds: MagicMock) -> None: """Test compute endpoint error handling.""" if setup_mocks: - # Create sample dataframe - sample_df = _create_sample_dataframe( - ["feature1", "feature2"], - df_type=df_type, - ) + columns = ["feature1", "feature2"] + sample_df = _create_sample_dataframe(columns, df_type=df_type) # Mock data source mock_data_source = MagicMock() mock_data_source.get_dataframe_by_tag = AsyncMock(return_value=sample_df) mock_data_source.get_organic_dataframe = AsyncMock(return_value=sample_df) mock_data_source.get_dataframe = AsyncMock(return_value=sample_df) + mock_metadata = MagicMock() + mock_metadata.input_schema.items.keys.return_value = columns + mock_data_source.get_metadata = AsyncMock(return_value=mock_metadata) mock_ds.return_value = mock_data_source # Send request diff --git a/tests/endpoints/metrics/drift/test_compare_means.py b/tests/endpoints/metrics/drift/test_compare_means.py index 7ac3ab5b..35d04763 100644 --- a/tests/endpoints/metrics/drift/test_compare_means.py +++ b/tests/endpoints/metrics/drift/test_compare_means.py @@ -292,18 +292,18 @@ def test_multi_feature_drift_consistency_deterministic(self) -> None: expected_error_substring="referenceTag is required", ) - test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test( - metric_name="CompareMeans", - module_path="src.endpoints.metrics.drift.compare_means", - endpoint_path="/metrics/drift/comparemeans", - client=client, - request_payload={ - "modelId": "test-model", - "referenceTag": "baseline", - # Missing fitColumns - }, - expected_status_code=HTTPStatus.BAD_REQUEST, - expected_error_substring="fitColumns is required", + test_compute_missing_fit_columns_derives_from_metadata = ( + factory.make_compute_endpoint_test( + metric_name="CompareMeans", + module_path="src.endpoints.metrics.drift.compare_means", + endpoint_path="/metrics/drift/comparemeans", + client=client, + request_payload={ + "modelId": "test-model", + "referenceTag": "baseline", + }, + expected_response_keys=["status", "value", "drift_detected"], + ) ) test_compute_invalid_feature = factory.make_compute_endpoint_error_test( diff --git a/tests/endpoints/metrics/drift/test_jensen_shannon.py b/tests/endpoints/metrics/drift/test_jensen_shannon.py index a86d2901..f474d970 100644 --- a/tests/endpoints/metrics/drift/test_jensen_shannon.py +++ b/tests/endpoints/metrics/drift/test_jensen_shannon.py @@ -129,18 +129,18 @@ class TestJensenShannonEndpoints: expected_error_substring="referenceTag is required", ) - test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test( - metric_name="JensenShannon", - module_path="src.endpoints.metrics.drift.jensen_shannon", - endpoint_path="/metrics/drift/jensenshannon", - client=client, - request_payload={ - "modelId": "test-model", - "referenceTag": "baseline", - # Missing fitColumns - }, - expected_status_code=HTTPStatus.BAD_REQUEST, - expected_error_substring="fitColumns is required", + test_compute_missing_fit_columns_derives_from_metadata = ( + factory.make_compute_endpoint_test( + metric_name="JensenShannon", + module_path="src.endpoints.metrics.drift.jensen_shannon", + endpoint_path="/metrics/drift/jensenshannon", + client=client, + request_payload={ + "modelId": "test-model", + "referenceTag": "baseline", + }, + expected_response_keys=["status", "value", "drift_detected"], + ) ) test_compute_invalid_feature = factory.make_compute_endpoint_error_test( diff --git a/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py b/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py index 9de2d5d6..08662351 100644 --- a/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py +++ b/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py @@ -114,18 +114,18 @@ class TestKSTestEndpoints: expected_error_substring="referenceTag is required", ) - test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test( - metric_name="KSTest", - module_path="src.endpoints.metrics.drift.kolmogorov_smirnov", - endpoint_path="/metrics/drift/kstest", - client=client, - request_payload={ - "modelId": "test-model", - "referenceTag": "baseline", - # Missing fitColumns - }, - expected_status_code=HTTPStatus.BAD_REQUEST, - expected_error_substring="fitColumns is required", + test_compute_missing_fit_columns_derives_from_metadata = ( + factory.make_compute_endpoint_test( + metric_name="KSTest", + module_path="src.endpoints.metrics.drift.kolmogorov_smirnov", + endpoint_path="/metrics/drift/kstest", + client=client, + request_payload={ + "modelId": "test-model", + "referenceTag": "baseline", + }, + expected_response_keys=["status", "value", "drift_detected"], + ) ) test_compute_invalid_feature = factory.make_compute_endpoint_error_test(