Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions src/endpoints/metrics/drift/compare_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,22 @@ async def compute_compare_means(
)

if not request.fit_columns:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns is required - specify which features to test for drift",
)

# Validate feature names are not blank/whitespace
valid_features = [f.strip() for f in request.fit_columns if f.strip()]
if not valid_features:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns must contain at least one non-empty feature name",
data_source = get_data_source()
metadata = await data_source.get_metadata(request.model_id)
request.fit_columns = list(metadata.input_schema.items.keys())
logger.info(
"fitColumns not specified, using all input columns for model %s: %s",
request.model_id,
request.fit_columns,
)
else:
valid_features = [f.strip() for f in request.fit_columns if f.strip()]
if not valid_features:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns must contain at least one non-empty feature name",
)
request.fit_columns = valid_features

try:
logger.info("Computing %s for model: %s", METRIC_NAME, request.model_id)
Expand Down Expand Up @@ -153,7 +157,7 @@ async def compute_compare_means(

# Multi-feature case: iterate over features
results = {}
for feature_name in valid_features:
for feature_name in request.fit_columns:
if (
feature_name not in reference_df.columns
or feature_name not in current_df.columns
Expand Down Expand Up @@ -254,18 +258,22 @@ async def schedule_compare_means(request: CompareMeansMetricRequest) -> dict[str
)

if not request.fit_columns:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns is required - specify which features to test for drift",
)

# Validate feature names are not blank/whitespace
valid_features = [f.strip() for f in request.fit_columns if f.strip()]
if not valid_features:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns must contain at least one non-empty feature name",
data_source = get_data_source()
metadata = await data_source.get_metadata(request.model_id)
request.fit_columns = list(metadata.input_schema.items.keys())
logger.info(
"fitColumns not specified, using all input columns for model %s: %s",
request.model_id,
request.fit_columns,
)
else:
valid_features = [f.strip() for f in request.fit_columns if f.strip()]
if not valid_features:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns must contain at least one non-empty feature name",
)
request.fit_columns = valid_features

try:
# Generate UUID for this request
Expand Down
20 changes: 14 additions & 6 deletions src/endpoints/metrics/drift/jensen_shannon.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,13 @@ async def compute_jensenshannon(
)

if not request.fit_columns:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns is required - specify which features to test for drift",
data_source = get_data_source()
metadata = await data_source.get_metadata(request.model_id)
request.fit_columns = list(metadata.input_schema.items.keys())
logger.info(
"fitColumns not specified, using all input columns for model %s: %s",
request.model_id,
request.fit_columns,
)

try:
Expand Down Expand Up @@ -231,9 +235,13 @@ async def schedule_jensenshannon(request: JensenShannonMetricRequest) -> dict[st
)

if not request.fit_columns:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns is required - specify which features to test for drift",
data_source = get_data_source()
metadata = await data_source.get_metadata(request.model_id)
request.fit_columns = list(metadata.input_schema.items.keys())
logger.info(
"fitColumns not specified, using all input columns for model %s: %s",
request.model_id,
request.fit_columns,
)

# Get the scheduler and validate availability
Expand Down
20 changes: 17 additions & 3 deletions src/endpoints/metrics/drift/kolmogorov_smirnov.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ async def compute_kstest(
)

if not request.fit_columns:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="fitColumns is required - specify which features to test for drift",
data_source = get_data_source()
metadata = await data_source.get_metadata(request.model_id)
request.fit_columns = list(metadata.input_schema.items.keys())
logger.info(
"fitColumns not specified, using all input columns for model %s: %s",
request.model_id,
request.fit_columns,
)

try:
Expand Down Expand Up @@ -186,6 +190,16 @@ async def get_kstest_definition() -> dict[str, str]:
@router.post("/metrics/drift/kstest/request")
async def schedule_kstest(request: KSTestMetricRequest) -> dict[str, str]:
"""Schedule a recurring computation of KSTest metric."""
if not request.fit_columns:
data_source = get_data_source()
metadata = await data_source.get_metadata(request.model_id)
request.fit_columns = list(metadata.input_schema.items.keys())
logger.info(
"fitColumns not specified, using all input columns for model %s: %s",
request.model_id,
request.fit_columns,
)

# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
if not scheduler:
Expand Down
19 changes: 10 additions & 9 deletions tests/endpoints/metrics/drift/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ def make_compute_endpoint_test(
def test_impl(_: object, mock_ds: MagicMock) -> None:
"""Test compute endpoint returns valid response structure."""
# Create sample dataframe (Pandas or Polars based on df_type)
sample_df = _create_sample_dataframe(
request_payload.get("fitColumns", ["feature1"]),
df_type=df_type,
)
columns = request_payload.get("fitColumns", ["feature1"])
sample_df = _create_sample_dataframe(columns, df_type=df_type)

# Mock data source
mock_data_source = MagicMock()
mock_data_source.get_dataframe_by_tag = AsyncMock(return_value=sample_df)
mock_data_source.get_organic_dataframe = AsyncMock(return_value=sample_df)
mock_data_source.get_dataframe = AsyncMock(return_value=sample_df)
mock_metadata = MagicMock()
mock_metadata.input_schema.items.keys.return_value = columns
mock_data_source.get_metadata = AsyncMock(return_value=mock_metadata)
mock_ds.return_value = mock_data_source

# Send request
Expand Down Expand Up @@ -338,17 +339,17 @@ def make_compute_endpoint_error_test(
def test_impl(_: object, mock_ds: MagicMock) -> None:
"""Test compute endpoint error handling."""
if setup_mocks:
# Create sample dataframe
sample_df = _create_sample_dataframe(
["feature1", "feature2"],
df_type=df_type,
)
columns = ["feature1", "feature2"]
sample_df = _create_sample_dataframe(columns, df_type=df_type)

# Mock data source
mock_data_source = MagicMock()
mock_data_source.get_dataframe_by_tag = AsyncMock(return_value=sample_df)
mock_data_source.get_organic_dataframe = AsyncMock(return_value=sample_df)
mock_data_source.get_dataframe = AsyncMock(return_value=sample_df)
mock_metadata = MagicMock()
mock_metadata.input_schema.items.keys.return_value = columns
mock_data_source.get_metadata = AsyncMock(return_value=mock_metadata)
mock_ds.return_value = mock_data_source

# Send request
Expand Down
24 changes: 12 additions & 12 deletions tests/endpoints/metrics/drift/test_compare_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,18 +292,18 @@ def test_multi_feature_drift_consistency_deterministic(self) -> None:
expected_error_substring="referenceTag is required",
)

test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test(
metric_name="CompareMeans",
module_path="src.endpoints.metrics.drift.compare_means",
endpoint_path="/metrics/drift/comparemeans",
client=client,
request_payload={
"modelId": "test-model",
"referenceTag": "baseline",
# Missing fitColumns
},
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_error_substring="fitColumns is required",
test_compute_missing_fit_columns_derives_from_metadata = (
factory.make_compute_endpoint_test(
metric_name="CompareMeans",
module_path="src.endpoints.metrics.drift.compare_means",
endpoint_path="/metrics/drift/comparemeans",
client=client,
request_payload={
"modelId": "test-model",
"referenceTag": "baseline",
},
expected_response_keys=["status", "value", "drift_detected"],
)
)

test_compute_invalid_feature = factory.make_compute_endpoint_error_test(
Expand Down
24 changes: 12 additions & 12 deletions tests/endpoints/metrics/drift/test_jensen_shannon.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,18 @@ class TestJensenShannonEndpoints:
expected_error_substring="referenceTag is required",
)

test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test(
metric_name="JensenShannon",
module_path="src.endpoints.metrics.drift.jensen_shannon",
endpoint_path="/metrics/drift/jensenshannon",
client=client,
request_payload={
"modelId": "test-model",
"referenceTag": "baseline",
# Missing fitColumns
},
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_error_substring="fitColumns is required",
test_compute_missing_fit_columns_derives_from_metadata = (
factory.make_compute_endpoint_test(
metric_name="JensenShannon",
module_path="src.endpoints.metrics.drift.jensen_shannon",
endpoint_path="/metrics/drift/jensenshannon",
client=client,
request_payload={
"modelId": "test-model",
"referenceTag": "baseline",
},
expected_response_keys=["status", "value", "drift_detected"],
)
)

test_compute_invalid_feature = factory.make_compute_endpoint_error_test(
Expand Down
24 changes: 12 additions & 12 deletions tests/endpoints/metrics/drift/test_kolmogorov_smirnov.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,18 @@ class TestKSTestEndpoints:
expected_error_substring="referenceTag is required",
)

test_compute_missing_fit_columns = factory.make_compute_endpoint_error_test(
metric_name="KSTest",
module_path="src.endpoints.metrics.drift.kolmogorov_smirnov",
endpoint_path="/metrics/drift/kstest",
client=client,
request_payload={
"modelId": "test-model",
"referenceTag": "baseline",
# Missing fitColumns
},
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_error_substring="fitColumns is required",
test_compute_missing_fit_columns_derives_from_metadata = (
factory.make_compute_endpoint_test(
metric_name="KSTest",
module_path="src.endpoints.metrics.drift.kolmogorov_smirnov",
endpoint_path="/metrics/drift/kstest",
client=client,
request_payload={
"modelId": "test-model",
"referenceTag": "baseline",
},
expected_response_keys=["status", "value", "drift_detected"],
)
)

test_compute_invalid_feature = factory.make_compute_endpoint_error_test(
Expand Down
Loading