diff --git a/src/endpoints/metrics/drift/compare_means.py b/src/endpoints/metrics/drift/compare_means.py index 2b9a7e9a..6f0df177 100644 --- a/src/endpoints/metrics/drift/compare_means.py +++ b/src/endpoints/metrics/drift/compare_means.py @@ -103,18 +103,22 @@ async def compute_compare_means( ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", - ) - - # Validate feature names are not blank/whitespace - valid_features = [f.strip() for f in request.fit_columns if f.strip()] - if not valid_features: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns must contain at least one non-empty feature name", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) + else: + valid_features = [f.strip() for f in request.fit_columns if f.strip()] + if not valid_features: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="fitColumns must contain at least one non-empty feature name", + ) + request.fit_columns = valid_features try: logger.info("Computing %s for model: %s", METRIC_NAME, request.model_id) @@ -153,7 +157,7 @@ async def compute_compare_means( # Multi-feature case: iterate over features results = {} - for feature_name in valid_features: + for feature_name in request.fit_columns: if ( feature_name not in reference_df.columns or feature_name not in current_df.columns @@ -254,18 +258,22 @@ async def schedule_compare_means(request: CompareMeansMetricRequest) -> dict[str ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", - ) - - # Validate feature names are not blank/whitespace - valid_features = [f.strip() for f in request.fit_columns if f.strip()] - if not valid_features: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns must contain at least one non-empty feature name", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) + else: + valid_features = [f.strip() for f in request.fit_columns if f.strip()] + if not valid_features: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="fitColumns must contain at least one non-empty feature name", + ) + request.fit_columns = valid_features try: # Generate UUID for this request diff --git a/src/endpoints/metrics/drift/jensen_shannon.py b/src/endpoints/metrics/drift/jensen_shannon.py index 36cef0f2..7977fe83 100644 --- a/src/endpoints/metrics/drift/jensen_shannon.py +++ b/src/endpoints/metrics/drift/jensen_shannon.py @@ -106,9 +106,13 @@ async def compute_jensenshannon( ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) try: @@ -231,9 +235,13 @@ async def schedule_jensenshannon(request: JensenShannonMetricRequest) -> dict[st ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) # Get the scheduler and validate availability diff --git a/src/endpoints/metrics/drift/kolmogorov_smirnov.py b/src/endpoints/metrics/drift/kolmogorov_smirnov.py index 7875fa02..bda435d3 100644 --- a/src/endpoints/metrics/drift/kolmogorov_smirnov.py +++ b/src/endpoints/metrics/drift/kolmogorov_smirnov.py @@ -83,9 +83,13 @@ async def compute_kstest( ) if not request.fit_columns: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail="fitColumns is required - specify which features to test for drift", + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, ) try: @@ -186,6 +190,16 @@ async def get_kstest_definition() -> dict[str, str]: @router.post("/metrics/drift/kstest/request") async def schedule_kstest(request: KSTestMetricRequest) -> dict[str, str]: """Schedule a recurring computation of KSTest metric.""" + if not request.fit_columns: + data_source = get_data_source() + metadata = await data_source.get_metadata(request.model_id) + request.fit_columns = list(metadata.input_schema.items.keys()) + logger.info( + "fitColumns not specified, using all input columns for model %s: %s", + request.model_id, + request.fit_columns, + ) + # Get the scheduler and validate availability scheduler = get_prometheus_scheduler() if not scheduler: diff --git a/tests/endpoints/metrics/drift/factory.py b/tests/endpoints/metrics/drift/factory.py index bd680689..95c37c9d 100644 --- a/tests/endpoints/metrics/drift/factory.py +++ b/tests/endpoints/metrics/drift/factory.py @@ -73,16 +73,17 @@ def make_compute_endpoint_test( def test_impl(_: object, mock_ds: MagicMock) -> None: """Test compute endpoint returns valid response structure.""" # Create sample dataframe (Pandas or Polars based on df_type) - sample_df = _create_sample_dataframe( - request_payload.get("fitColumns", ["feature1"]), - df_type=df_type, - ) + columns = request_payload.get("fitColumns", ["feature1"]) + sample_df = _create_sample_dataframe(columns, df_type=df_type) # Mock data source mock_data_source = MagicMock() mock_data_source.get_dataframe_by_tag = AsyncMock(return_value=sample_df) mock_data_source.get_organic_dataframe = AsyncMock(return_value=sample_df) mock_data_source.get_dataframe = AsyncMock(return_value=sample_df) + mock_metadata = MagicMock() + mock_metadata.input_schema.items.keys.return_value = columns + mock_data_source.get_metadata = AsyncMock(return_value=mock_metadata) mock_ds.return_value = mock_data_source # Send request @@ -338,17 +339,17 @@ def make_compute_endpoint_error_test( def test_impl(_: object, mock_ds: MagicMock) -> None: """Test compute endpoint error handling.""" if setup_mocks: - # Create sample dataframe - sample_df = _create_sample_dataframe( - ["feature1", "feature2"], - df_type=df_type, - ) + columns = ["feature1", "feature2"] + sample_df = _create_sample_dataframe(columns, df_type=df_type) # Mock data source mock_data_source = MagicMock() mock_data_source.get_dataframe_by_tag = AsyncMock(return_value=sample_df) mock_data_source.get_organic_dataframe = AsyncMock(return_value=sample_df) mock_data_source.get_dataframe = AsyncMock(return_value=sample_df) + mock_metadata = MagicMock() + mock_metadata.input_schema.items.keys.return_value = columns + mock_data_source.get_metadata = AsyncMock(return_value=mock_metadata) mock_ds.return_value = mock_data_source # Send request diff --git a/tests/endpoints/metrics/drift/test_compare_means.py b/tests/endpoints/metrics/drift/test_compare_means.py index 7ac3ab5b..35d04763 100644 --- a/tests/endpoints/metrics/drift/test_compare_means.py +++ b/tests/endpoints/metrics/drift/test_compare_means.py @@ -292,18 +292,18 @@ def test_multi_feature_drift_consistency_deterministic(self) -> None: expected_error_substring="referenceTag is required", ) - test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test( - metric_name="CompareMeans", - module_path="src.endpoints.metrics.drift.compare_means", - endpoint_path="/metrics/drift/comparemeans", - client=client, - request_payload={ - "modelId": "test-model", - "referenceTag": "baseline", - # Missing fitColumns - }, - expected_status_code=HTTPStatus.BAD_REQUEST, - expected_error_substring="fitColumns is required", + test_compute_missing_fit_columns_derives_from_metadata = ( + factory.make_compute_endpoint_test( + metric_name="CompareMeans", + module_path="src.endpoints.metrics.drift.compare_means", + endpoint_path="/metrics/drift/comparemeans", + client=client, + request_payload={ + "modelId": "test-model", + "referenceTag": "baseline", + }, + expected_response_keys=["status", "value", "drift_detected"], + ) ) test_compute_invalid_feature = factory.make_compute_endpoint_error_test( diff --git a/tests/endpoints/metrics/drift/test_jensen_shannon.py b/tests/endpoints/metrics/drift/test_jensen_shannon.py index a86d2901..f474d970 100644 --- a/tests/endpoints/metrics/drift/test_jensen_shannon.py +++ b/tests/endpoints/metrics/drift/test_jensen_shannon.py @@ -129,18 +129,18 @@ class TestJensenShannonEndpoints: expected_error_substring="referenceTag is required", ) - test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test( - metric_name="JensenShannon", - module_path="src.endpoints.metrics.drift.jensen_shannon", - endpoint_path="/metrics/drift/jensenshannon", - client=client, - request_payload={ - "modelId": "test-model", - "referenceTag": "baseline", - # Missing fitColumns - }, - expected_status_code=HTTPStatus.BAD_REQUEST, - expected_error_substring="fitColumns is required", + test_compute_missing_fit_columns_derives_from_metadata = ( + factory.make_compute_endpoint_test( + metric_name="JensenShannon", + module_path="src.endpoints.metrics.drift.jensen_shannon", + endpoint_path="/metrics/drift/jensenshannon", + client=client, + request_payload={ + "modelId": "test-model", + "referenceTag": "baseline", + }, + expected_response_keys=["status", "value", "drift_detected"], + ) ) test_compute_invalid_feature = factory.make_compute_endpoint_error_test( diff --git a/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py b/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py index 9de2d5d6..08662351 100644 --- a/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py +++ b/tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py @@ -114,18 +114,18 @@ class TestKSTestEndpoints: expected_error_substring="referenceTag is required", ) - test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test( - metric_name="KSTest", - module_path="src.endpoints.metrics.drift.kolmogorov_smirnov", - endpoint_path="/metrics/drift/kstest", - client=client, - request_payload={ - "modelId": "test-model", - "referenceTag": "baseline", - # Missing fitColumns - }, - expected_status_code=HTTPStatus.BAD_REQUEST, - expected_error_substring="fitColumns is required", + test_compute_missing_fit_columns_derives_from_metadata = ( + factory.make_compute_endpoint_test( + metric_name="KSTest", + module_path="src.endpoints.metrics.drift.kolmogorov_smirnov", + endpoint_path="/metrics/drift/kstest", + client=client, + request_payload={ + "modelId": "test-model", + "referenceTag": "baseline", + }, + expected_response_keys=["status", "value", "drift_detected"], + ) ) test_compute_invalid_feature = factory.make_compute_endpoint_error_test(