Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 66 additions & 10 deletions src/endpoints/metrics/drift/compare_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from http import HTTPStatus
from typing import Any

import pandas as pd
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict, Field, model_validator

Expand All @@ -18,6 +19,7 @@
from src.service.data.datasources.data_source import DataSource
from src.service.data.shared_data_source import get_shared_data_source
from src.service.payloads.metrics.base_metric_request import BaseMetricRequest
from src.service.prometheus.metric_value_carrier import MetricValueCarrier
from src.service.prometheus.prometheus_scheduler import PrometheusScheduler
from src.service.prometheus.shared_prometheus_scheduler import (
get_shared_prometheus_scheduler,
Expand All @@ -28,8 +30,8 @@
logger = logging.getLogger(__name__)

# Metric name constants
METRIC_NAME = "CompareMeans"
DEPRECATED_METRIC_NAME = "Meanshift" # Legacy name for backwards compatibility
METRIC_NAME = "COMPAREMEANS"
DEPRECATED_METRIC_NAME = "MEANSHIFT"

# Default parameter values
DEFAULT_BATCH_SIZE = 100
Expand Down Expand Up @@ -273,7 +275,8 @@ async def schedule_compare_means(request: CompareMeansMetricRequest) -> dict[str
logger.info("Scheduling %s computation with ID: %s.", METRIC_NAME, request_id)

# Set metric name automatically
request.metric_name = METRIC_NAME
if not request.metric_name:
request.metric_name = METRIC_NAME

# Register with the scheduler (this will reconcile the request and store it)
await scheduler.register(request.metric_name, request_id, request)
Expand All @@ -294,7 +297,9 @@ async def schedule_compare_means(request: CompareMeansMetricRequest) -> dict[str


@router.delete("/metrics/drift/comparemeans/request")
async def delete_compare_means_schedule(schedule: ScheduleId) -> dict[str, str]:
async def delete_compare_means_schedule(
schedule: ScheduleId, metric_name: str = METRIC_NAME
) -> dict[str, str]:
"""Delete a recurring computation of CompareMeans metric."""
# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
Expand All @@ -316,7 +321,7 @@ async def delete_compare_means_schedule(schedule: ScheduleId) -> dict[str, str]:
logger.info("Deleting %s schedule: %s", METRIC_NAME, schedule.requestId)

# Delete from scheduler
await scheduler.delete(METRIC_NAME, request_uuid)
await scheduler.delete(metric_name, request_uuid)

except HTTPException:
raise
Expand All @@ -339,7 +344,9 @@ async def delete_compare_means_schedule(schedule: ScheduleId) -> dict[str, str]:


@router.get("/metrics/drift/comparemeans/requests")
async def list_compare_means_requests() -> dict[str, list[dict[str, Any]]]:
async def list_compare_means_requests(
metric_name: str = METRIC_NAME,
) -> dict[str, list[dict[str, Any]]]:
"""List the currently scheduled computations of CompareMeans metric."""
# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
Expand All @@ -350,8 +357,7 @@ async def list_compare_means_requests() -> dict[str, list[dict[str, Any]]]:
)

try:
# Get all requests for CompareMeans
requests = scheduler.get_requests(METRIC_NAME)
requests = scheduler.get_requests(metric_name)

# Convert to list format expected by client
requests_list = []
Expand Down Expand Up @@ -460,6 +466,7 @@ async def schedule_meanshift(request: MeanshiftMetricRequest) -> dict[str, str]:
compare_means_request = CompareMeansMetricRequest.model_validate(
request.model_dump(exclude_none=True)
)
compare_means_request.metric_name = DEPRECATED_METRIC_NAME
return await schedule_compare_means(compare_means_request)


Expand All @@ -471,7 +478,9 @@ async def delete_meanshift_schedule(schedule: ScheduleId) -> dict[str, str]:
/metrics/drift/comparemeans/request instead.
"""
log_deprecated_endpoint(logger, DEPRECATED_METRIC_NAME, METRIC_NAME)
return await delete_compare_means_schedule(schedule)
return await delete_compare_means_schedule(
schedule, metric_name=DEPRECATED_METRIC_NAME
)


@router.get("/metrics/drift/meanshift/requests", deprecated=True)
Expand All @@ -482,4 +491,51 @@ async def list_meanshift_requests() -> dict[str, list[dict[str, Any]]]:
/metrics/drift/comparemeans/requests instead.
"""
log_deprecated_endpoint(logger, DEPRECATED_METRIC_NAME, METRIC_NAME)
return await list_compare_means_requests()
return await list_compare_means_requests(metric_name=DEPRECATED_METRIC_NAME)


async def calculate_compare_means_metric(
    batch: pd.DataFrame,
    request: BaseMetricRequest,
) -> MetricValueCarrier:
    """Compute per-feature CompareMeans statistics for a scheduled request.

    The reference dataframe is fetched from the data source using the
    request's ``model_id`` and ``reference_tag``. Each column named in
    ``request.fit_columns`` (or each column of ``batch`` when that value is
    falsy) is tested with ``CompareMeans.ttest_ind``, provided the column is
    present in both dataframes; other columns are skipped.

    Args:
        batch: Current data to compare against the reference data.
        request: Scheduled metric request. ``alpha``, ``equal_var`` and
            ``nan_policy`` are read from it when the attributes exist;
            otherwise ``DEFAULT_ALPHA``, ``DEFAULT_EQUAL_VAR`` and
            ``DEFAULT_NAN_POLICY`` are used.

    Returns:
        A ``MetricValueCarrier`` built from a mapping of feature name to the
        ``"statistic"`` entry of its test result, or from ``0.0`` when no
        feature was evaluated.
    """
    reference = await get_data_source().get_dataframe_by_tag(
        request.model_id, request.reference_tag
    )
    candidates = request.fit_columns or list(batch.columns)
    significance = getattr(request, "alpha", DEFAULT_ALPHA)
    assume_equal_var = getattr(request, "equal_var", DEFAULT_EQUAL_VAR)
    nan_handling = getattr(request, "nan_policy", DEFAULT_NAN_POLICY)

    # Only columns present on both sides can be compared.
    statistics = {
        column: CompareMeans.ttest_ind(
            reference_data=reference[column].to_numpy(),
            current_data=batch[column].to_numpy(),
            alpha=significance,
            equal_var=assume_equal_var,
            nan_policy=nan_handling,
        )["statistic"]
        for column in candidates
        if column in reference.columns and column in batch.columns
    }
    return MetricValueCarrier(statistics or 0.0)


def _register_compare_means_calculator() -> None:
    """Attach the CompareMeans calculator to the scheduler's metrics directory.

    The same calculator is registered under both ``METRIC_NAME`` and
    ``DEPRECATED_METRIC_NAME``. Nothing is registered or logged when the
    scheduler or its metrics directory is falsy.
    """
    scheduler = get_prometheus_scheduler()
    if not (scheduler and scheduler.metrics_directory):
        return

    directory = scheduler.metrics_directory
    for name in (METRIC_NAME, DEPRECATED_METRIC_NAME):
        directory.register(name, calculate_compare_means_metric)
    logger.info("%s calculator registered with metrics directory", METRIC_NAME)


# Registration is attempted at import time; an AttributeError or TypeError
# is logged as a warning instead of propagating out of the import.
try:
    _register_compare_means_calculator()
except (AttributeError, TypeError) as exc:
    logger.warning(
        "Could not register %s calculator on import: %s", METRIC_NAME, exc
    )
64 changes: 59 additions & 5 deletions src/endpoints/metrics/drift/jensen_shannon.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from http import HTTPStatus
from typing import Any, Literal, cast

import pandas as pd
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict, Field, model_validator

Expand All @@ -18,6 +19,7 @@
)
from src.service.data.shared_data_source import DataSource, get_shared_data_source
from src.service.payloads.metrics.base_metric_request import BaseMetricRequest
from src.service.prometheus.metric_value_carrier import MetricValueCarrier
from src.service.prometheus.prometheus_scheduler import PrometheusScheduler
from src.service.prometheus.shared_prometheus_scheduler import (
get_shared_prometheus_scheduler,
Expand All @@ -27,7 +29,7 @@
logger = logging.getLogger(__name__)

# Metric name constant
METRIC_NAME = "JensenShannon"
METRIC_NAME = "JENSENSHANNON"


def get_prometheus_scheduler() -> PrometheusScheduler:
Expand Down Expand Up @@ -269,7 +271,9 @@ async def schedule_jensenshannon(request: JensenShannonMetricRequest) -> dict[st


@router.delete("/metrics/drift/jensenshannon/request")
async def delete_jensenshannon_schedule(schedule: ScheduleId) -> dict[str, str]:
async def delete_jensenshannon_schedule(
schedule: ScheduleId, metric_name: str = METRIC_NAME
) -> dict[str, str]:
"""Delete a recurring computation of Jensen-Shannon metric."""
# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
Expand All @@ -291,7 +295,7 @@ async def delete_jensenshannon_schedule(schedule: ScheduleId) -> dict[str, str]:
logger.info("Deleting %s schedule: %s", METRIC_NAME, schedule.requestId)

# Delete from scheduler
await scheduler.delete(METRIC_NAME, request_uuid)
await scheduler.delete(metric_name, request_uuid)

except HTTPException:
raise
Expand All @@ -314,7 +318,9 @@ async def delete_jensenshannon_schedule(schedule: ScheduleId) -> dict[str, str]:


@router.get("/metrics/drift/jensenshannon/requests")
async def list_jensenshannon_requests() -> dict[str, list[dict[str, Any]]]:
async def list_jensenshannon_requests(
metric_name: str = METRIC_NAME,
) -> dict[str, list[dict[str, Any]]]:
"""List the currently scheduled computations of Jensen-Shannon metric."""
# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
Expand All @@ -326,7 +332,7 @@ async def list_jensenshannon_requests() -> dict[str, list[dict[str, Any]]]:

try:
# Get all requests for JensenShannon
requests = scheduler.get_requests(METRIC_NAME)
requests = scheduler.get_requests(metric_name)

# Convert to list format expected by client
requests_list = []
Expand Down Expand Up @@ -374,3 +380,51 @@ async def list_jensenshannon_requests() -> dict[str, list[dict[str, Any]]]:
) from e
else:
return {"requests": requests_list}


async def calculate_jensenshannon_metric(
    batch: pd.DataFrame,
    request: BaseMetricRequest,
) -> MetricValueCarrier:
    """Compute per-feature Jensen-Shannon values for a scheduled request.

    The reference dataframe is fetched from the data source using the
    request's ``model_id`` and ``reference_tag``. Each column named in
    ``request.fit_columns`` (or each column of ``batch`` when that value is
    falsy) is passed to ``JensenShannon.jensenshannon``, provided the column
    is present in both dataframes; other columns are skipped.

    Args:
        batch: Current data to compare against the reference data.
        request: Scheduled metric request. ``statistic``, ``threshold``,
            ``method``, ``grid_points`` and ``bins`` are read from it when
            the attributes exist; otherwise the matching ``DEFAULT_*``
            constants are used.

    Returns:
        A ``MetricValueCarrier`` built from a mapping of feature name to the
        ``"Jensen-Shannon_distance"`` entry of its result, or from ``0.0``
        when no feature was evaluated.
    """
    reference = await get_data_source().get_dataframe_by_tag(
        request.model_id, request.reference_tag
    )
    candidates = request.fit_columns or list(batch.columns)

    # cast() only informs the type checker; it does not validate at runtime.
    statistic_kind = cast(
        "Literal['distance', 'divergence']",
        getattr(request, "statistic", DEFAULT_STATISTIC),
    )
    drift_threshold = getattr(request, "threshold", DEFAULT_THRESHOLD)
    estimator = cast(
        "Literal['kde', 'hist']", getattr(request, "method", DEFAULT_METHOD)
    )
    n_grid_points = getattr(request, "grid_points", DEFAULT_GRID_POINTS)
    n_bins = getattr(request, "bins", DEFAULT_BINS)

    # Only columns present on both sides can be compared.
    distances = {
        column: JensenShannon.jensenshannon(
            data_ref=reference[column].to_numpy(),
            data_cur=batch[column].to_numpy(),
            statistic=statistic_kind,
            threshold=drift_threshold,
            method=estimator,
            grid_points=n_grid_points,
            bins=n_bins,
        )["Jensen-Shannon_distance"]
        for column in candidates
        if column in reference.columns and column in batch.columns
    }
    return MetricValueCarrier(distances or 0.0)


def _register_jensenshannon_calculator() -> None:
    """Attach the JensenShannon calculator to the scheduler's metrics directory.

    Nothing is registered or logged when the scheduler or its metrics
    directory is falsy.
    """
    scheduler = get_prometheus_scheduler()
    if not (scheduler and scheduler.metrics_directory):
        return

    scheduler.metrics_directory.register(METRIC_NAME, calculate_jensenshannon_metric)
    logger.info("%s calculator registered with metrics directory", METRIC_NAME)


# Registration is attempted at import time; an AttributeError or TypeError
# is logged as a warning instead of propagating out of the import.
try:
    _register_jensenshannon_calculator()
except (AttributeError, TypeError) as exc:
    logger.warning(
        "Could not register %s calculator on import: %s", METRIC_NAME, exc
    )
58 changes: 51 additions & 7 deletions src/endpoints/metrics/drift/kolmogorov_smirnov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from http import HTTPStatus
from typing import Any

import pandas as pd
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict, Field

from src.core.metrics.drift.kolmogorov_smirnov import KolmogorovSmirnov
from src.service.data.datasources.data_source import DataSource
from src.service.data.shared_data_source import get_shared_data_source
from src.service.payloads.metrics.base_metric_request import BaseMetricRequest
from src.service.prometheus.metric_value_carrier import MetricValueCarrier
from src.service.prometheus.prometheus_scheduler import PrometheusScheduler
from src.service.prometheus.shared_prometheus_scheduler import (
get_shared_prometheus_scheduler,
Expand All @@ -21,7 +23,7 @@
logger = logging.getLogger(__name__)

# Metric name constant
METRIC_NAME = "KSTest"
METRIC_NAME = "KSTEST"


def get_prometheus_scheduler() -> PrometheusScheduler:
Expand Down Expand Up @@ -199,8 +201,8 @@ async def schedule_kstest(request: KSTestMetricRequest) -> dict[str, str]:
request_id = uuid.uuid4()
logger.info("Scheduling %s computation with ID: %s.", METRIC_NAME, request_id)

# Set metric name automatically
request.metric_name = METRIC_NAME
if not request.metric_name:
request.metric_name = METRIC_NAME

# Register with the scheduler (this will reconcile the request and store it)
await scheduler.register(request.metric_name, request_id, request)
Expand All @@ -219,7 +221,9 @@ async def schedule_kstest(request: KSTestMetricRequest) -> dict[str, str]:


@router.delete("/metrics/drift/kstest/request")
async def delete_kstest_schedule(schedule: ScheduleId) -> dict[str, str]:
async def delete_kstest_schedule(
schedule: ScheduleId, metric_name: str = METRIC_NAME
) -> dict[str, str]:
"""Delete a recurring computation of KSTest metric."""
# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
Expand All @@ -241,7 +245,7 @@ async def delete_kstest_schedule(schedule: ScheduleId) -> dict[str, str]:
logger.info("Deleting %s schedule: %s", METRIC_NAME, schedule.requestId)

# Delete from scheduler
await scheduler.delete(METRIC_NAME, request_uuid)
await scheduler.delete(metric_name, request_uuid)

except HTTPException:
raise
Expand All @@ -264,7 +268,9 @@ async def delete_kstest_schedule(schedule: ScheduleId) -> dict[str, str]:


@router.get("/metrics/drift/kstest/requests")
async def list_kstest_requests() -> dict[str, list[dict[str, Any]]]:
async def list_kstest_requests(
metric_name: str = METRIC_NAME,
) -> dict[str, list[dict[str, Any]]]:
"""List the currently scheduled computations of KSTest metric."""
# Get the scheduler and validate availability
scheduler = get_prometheus_scheduler()
Expand All @@ -276,7 +282,7 @@ async def list_kstest_requests() -> dict[str, list[dict[str, Any]]]:

try:
# Get all requests for KSTest
requests = scheduler.get_requests(METRIC_NAME)
requests = scheduler.get_requests(metric_name)

# Convert to list format expected by client
requests_list = []
Expand Down Expand Up @@ -319,3 +325,41 @@ async def list_kstest_requests() -> dict[str, list[dict[str, Any]]]:
) from e
else:
return {"requests": requests_list}


async def calculate_kstest_metric(
    batch: pd.DataFrame,
    request: BaseMetricRequest,
) -> MetricValueCarrier:
    """Compute per-feature KSTest statistics for a scheduled request.

    The reference dataframe is fetched from the data source using the
    request's ``model_id`` and ``reference_tag``. Each column named in
    ``request.fit_columns`` (or each column of ``batch`` when that value is
    falsy) is tested with ``KolmogorovSmirnov.kstest``, provided the column
    is present in both dataframes; other columns are skipped.

    Args:
        batch: Current data to compare against the reference data.
        request: Scheduled metric request. Its ``threshold_delta`` attribute
            is passed to the test as ``alpha``; ``0.05`` is used when the
            attribute does not exist.

    Returns:
        A ``MetricValueCarrier`` built from a mapping of feature name to the
        ``"statistic"`` entry of its test result, or from ``0.0`` when no
        feature was evaluated.
    """
    reference = await get_data_source().get_dataframe_by_tag(
        request.model_id, request.reference_tag
    )
    candidates = request.fit_columns or list(batch.columns)
    significance = getattr(request, "threshold_delta", 0.05)

    # Only columns present on both sides can be compared.
    statistics = {
        column: KolmogorovSmirnov.kstest(
            reference_data=reference[column].to_numpy(),
            current_data=batch[column].to_numpy(),
            alpha=significance,
        )["statistic"]
        for column in candidates
        if column in reference.columns and column in batch.columns
    }
    return MetricValueCarrier(statistics or 0.0)


def _register_kstest_calculator() -> None:
    """Attach the KSTest calculator to the scheduler's metrics directory.

    Nothing is registered or logged when the scheduler or its metrics
    directory is falsy.
    """
    scheduler = get_prometheus_scheduler()
    if not (scheduler and scheduler.metrics_directory):
        return

    scheduler.metrics_directory.register(METRIC_NAME, calculate_kstest_metric)
    logger.info("%s calculator registered with metrics directory", METRIC_NAME)


# Registration is attempted at import time; an AttributeError or TypeError
# is logged as a warning instead of propagating out of the import.
try:
    _register_kstest_calculator()
except (AttributeError, TypeError) as exc:
    logger.warning(
        "Could not register %s calculator on import: %s", METRIC_NAME, exc
    )
Loading
Loading