Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# Third-party
import cartopy.crs as ccrs
import cf_xarray as cfxr
import mllam_data_prep as mdp
import xarray as xr
from loguru import logger
Expand Down Expand Up @@ -68,6 +69,11 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
)
self._ds = xr.open_zarr(fp_ds, consolidated=True)

# XXX: make decoding of MultiIndex be based on the mdp version
self._ds = cfxr.decode_compress_to_multi_index(
self._ds, idxnames="grid_index"
)

if self._ds is None:
self._ds = mdp.create_dataset(config=self._config)
self._ds.to_zarr(fp_ds)
Expand Down
109 changes: 107 additions & 2 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
from typing import List, Union

# Third-party
import cf_xarray as cfxr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import xarray as xr
from loguru import logger

# Local
from .. import metrics, vis
Expand Down Expand Up @@ -372,15 +375,109 @@ def on_validation_epoch_end(self):
for metric_list in self.val_metrics.values():
metric_list.clear()

def _save_predictions_to_zarr(
self,
batch_times: torch.Tensor,
batch_predictions: torch.Tensor,
batch_idx: int,
zarr_output_path: str,
):
"""
Save state predictions for single batch to zarr dataset. Will append to
existing dataset for batch_idx > 0. Resulting dataset will contain a
variable named `state` with coordinates (start_time,
elapsed_forecast_duration, grid_index, state_feature).

Parameters
----------
batch_times : torch.Tensor[int]
The times for the batch, given as epoch time in nanoseconds. Shape
is (B, args.pred_steps) where B is the batch size and
args.pred_steps is the number of prediction steps.
batch_predictions : torch.Tensor[float]
The predictions for the batch, given as (B, args.pred_steps,
num_grid_nodes, d_f) where B is the batch size, args.pred_steps is
the number of prediction steps, num_grid_nodes is the number of
grid nodes, and d_f is the number of state features.
batch_idx : int
The index of the batch in the current epoch.
"""
batch_size = batch_predictions.shape[0]
# Convert predictions to DataArray using _create_dataarray_from_tensor
das_pred = []
for i in range(len(batch_times)):
da_pred = self._create_dataarray_from_tensor(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that _create_dataarray_from_tensor is also used here it seems unreasonable to not properly deal with the hack where a WeatherDataset is instantiatied at each call (

# TODO: creating an instance of WeatherDataset here on every call is
# not how this should be done but whether WeatherDataset should be
# provided to ARModel or where to put plotting still needs discussion
weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
). As is we are doing instatiation of O(NT) WeatherDatasets when saving to zarr (not a memory problem, as we throw them away, but very wastefull). This will quickly become a problem, as the WeatherDatasets will grow when we merge in more boundar-related changes.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an alternative hack to avoid constantly making new datasets: joeloskarsson@d277b08 however, this is still very much a hack and the TODO remains that we should handle this some proper way.

tensor=batch_predictions[i],
time=batch_times[i],
split="test",
category="state",
)

t0 = da_pred.coords["time"].values[0]
da_pred.coords["analysis_time"] = t0
da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0
Comment on lines +416 to +418

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please have another look over how target times are used here. batch_times being fed in here are not [analysis_time, analysis_time+time_step, analysis_time+2time_step, ...], but rather [analysis_time+time_step, analysis_time+2time_step, ...]. This also because batch_predictions does not include the state at the analysis time.

This means that start_time is currently not when the forecast was started.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just adjusting t0 would fix this: joeloskarsson@721ac5e However, you probably need to get the step length from somewhere else than in that commit, as self.step_length is not present on main.

# set CF-standard names for forecast data in anticipation of
# input/output to neural-lam eventually all being cf-compliant
da_pred.analysis_time.attrs[
"standard_name"
] = "forecast_reference_time"
da_pred.elapsed_forecast_duration.attrs[
"standard_name"
] = "forecast_period"
da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"})

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want to leave the time coordinate in this DataArray? It looks to me like those values will only be valid for one of the forecasts (and I am not entirely sure which one).

da_pred.name = "state"

das_pred.append(da_pred)

da_pred_batch = xr.concat(das_pred, dim="start_time")

# Apply chunking along analysis_time so that each batch is saved as a
# separate chunk
da_pred_batch = da_pred_batch.chunk({"start_time": batch_size})

# copy variables that contain the units and long_name attributes
# from the source datastore
# XXX: currently it is hardcoded in mllam-data-prep that these are
# called {data_category}_feature_{long_name,units}, this should
# probably be made a format string in mllam-data-prep so that these
# can be correctly parsed/constructed
for attr in ["long_name", "units"]:
var_name = f"state_feature_{attr}"
da_pred_batch.coords[var_name] = self._datastore._ds[var_name]

# to handle MultiIndexes (see below) we need to have an xr.Dataset, so
# we make that here. For now we are only making predictions for "state"
ds_pred_batch = da_pred_batch.to_dataset(name="state")

# we need to ensure that if `grid_index` is a MultiIndex, it is
# serialised so that it can be written to netcdf/zarr. We use
# `cf_xarray` for this (see
# https://cf-xarray.readthedocs.io/en/latest/coding.html) since they
# have implemented a cf-compliant way to safely roundtrip this
for idx_name in list(ds_pred_batch.indexes):
idx = ds_pred_batch.indexes[idx_name]
if isinstance(idx, pd.MultiIndex):
ds_pred_batch = cfxr.encode_multi_index_as_compress(
ds_pred_batch, idxnames=[idx_name]
)

if batch_idx == 0:
logger.info(f"Saving predictions to {zarr_output_path}")
ds_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True)
else:
ds_pred_batch.to_zarr(
zarr_output_path, mode="a", append_dim="start_time"
)
Comment on lines +464 to +470

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might need to be a bit careful with the time encoding here. In my research branch I get errors like:

UserWarning: Times can't be serialized faithfully to int64 with requested units 'days since 2020-02-12T01:20:00'. Serializing with units 'hours since 2020-02-12T01:20:00' instead. Set encoding['dtype'] to floating point dtype to serialize with units 'days since 2020-02-12T01:20:00'. Set encoding['units'] to 'hours since 2020-02-12T01:20:00' to silence this warning.

Basically, the first time step is saved as days, because it might be 00:00UTC, but consecutive steps might be in e.g. hours.


# pylint: disable-next=unused-argument
def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
# TODO Here batch_times can be used for plotting routines
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
# pred_steps, num_grid_nodes, d_f) or (d_f,)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)

time_step_loss = torch.mean(
self.loss(
Expand Down Expand Up @@ -436,6 +533,14 @@ def test_step(self, batch, batch_idx):
self.spatial_loss_maps.append(log_spatial_losses)
# (B, N_log, num_grid_nodes)

if self.args.save_eval_to_zarr_path:
self._save_predictions_to_zarr(
batch_times=batch_times,
batch_predictions=prediction,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These predictions are in the standardized scale. At some point before these are written to disk in the zarr they should be rescaled to the original data scale.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be done as joeloskarsson@d3f636e

batch_idx=batch_idx,
zarr_output_path=self.args.save_eval_to_zarr_path,
)

# Plot example predictions (on rank 0 only)
if (
self.trainer.is_global_zero
Expand Down
7 changes: 6 additions & 1 deletion neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
}


@logger.catch
@logger.catch(reraise=True)
def main(input_args=None):
"""Main function for training and evaluating models."""
parser = ArgumentParser(
Expand Down Expand Up @@ -166,6 +166,11 @@ def main(input_args=None):
help="Eval model on given data split (val/test) "
"(default: None (train model))",
)
parser.add_argument(
"--save_eval_to_zarr_path",
type=str,
help="Save evaluation results to zarr dataset at given path ",
)
parser.add_argument(
"--ar_steps_eval",
type=int,
Expand Down
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ dependencies = [
"torch-geometric==2.3.1",
"parse>=1.20.2",
"dataclass-wizard<0.31.0",
"mllam-data-prep>=0.5.0",
"mlflow>=2.16.2",
"boto3>=1.35.32",
"pynvml>=12.0.0",
"cf-xarray>=0.9.4",
"mllam-data-prep",
]
requires-python = ">=3.9"

[project.optional-dependencies]
dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"]
dev = [
"pre-commit>=3.8.0",
"pytest>=8.3.2",
"pooch>=1.8.2",
"pytest-dependency>=0.6.0",
]

[tool.setuptools]
py-modules = ["neural_lam"]
Expand Down Expand Up @@ -111,6 +117,9 @@ min-similarity-lines = 10
source = "scm"
fallback_version = "0.0.0"

[tool.uv.sources]
mllam-data-prep = { git = "https://github.com/leifdenby/mllam-data-prep", rev = "feat/inverse-ops" }

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
64 changes: 64 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Third-party
import pytest

# First-party
from neural_lam.train_model import main as train_model_main
from tests.conftest import init_datastore_example


@pytest.mark.dependency(depends=["test_training"])
def test_inference(request):
"""
Run inference on a trained model and save the results to a zarr dataset
through the command line interface.

NB: This test will need refactoring once we clean up the command line
interface
"""
datastore = init_datastore_example("mdp")

# NB: this is brittle and should be refactored when the command line
# interface is cleaned up so that tests point to neural-lam config files
# rather than datastore config files
nl_config_path = datastore.root_path / "config.yaml"

# fetch the path to the trained model that was saved by the training test
model_path = request.config.cache.get("model_checkpoint_path", None)
if model_path is None:
raise Exception("training test must be run first")

args = [
"--config_path",
nl_config_path,
"--model",
"graph_lam",
"--eval",
"test",
"--load",
model_path,
"--hidden_dim",
"4",
"--hidden_layers",
"1",
"--processor_layers",
"2",
"--mesh_aggr",
"sum",
"--lr",
"1.0e-3",
"--val_steps_to_log",
"1",
"3",
"--num_past_forcing_steps",
"1",
"--num_future_forcing_steps",
"1",
"--n_example_pred",
"1",
"--graph",
"1level",
"--save_eval_to_zarr_path",
"state_test.zarr",
]

train_model_main(args)
10 changes: 9 additions & 1 deletion tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from tests.conftest import init_datastore_example


@pytest.mark.dependency()
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_training(datastore_name):
def test_training(datastore_name, request):
datastore = init_datastore_example(datastore_name)

if not isinstance(datastore, BaseRegularGridDatastore):
Expand Down Expand Up @@ -104,3 +105,10 @@ class ModelArgs:
)
wandb.init()
trainer.fit(model=model, datamodule=data_module)

# save the path to the model checkpoint in to the request object so we can
# use in the inference test
request.config.cache.set(
"model_checkpoint_path",
model.trainer.checkpoint_callback.best_model_path,
)