-
Notifications
You must be signed in to change notification settings - Fork 271
Add writing to zarr dataset for eval-mode of trained models #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2f5c32e
b80d010
305b8d0
de104f4
2199e47
1105955
2fe8859
e699f3f
a3cc461
c54accf
f57ab61
82aed31
b6abd4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,11 +4,14 @@ | |
| from typing import List, Union | ||
|
|
||
| # Third-party | ||
| import cf_xarray as cfxr | ||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import pandas as pd | ||
| import pytorch_lightning as pl | ||
| import torch | ||
| import xarray as xr | ||
| from loguru import logger | ||
|
|
||
| # Local | ||
| from .. import metrics, vis | ||
|
|
@@ -372,15 +375,109 @@ def on_validation_epoch_end(self): | |
| for metric_list in self.val_metrics.values(): | ||
| metric_list.clear() | ||
|
|
||
| def _save_predictions_to_zarr( | ||
| self, | ||
| batch_times: torch.Tensor, | ||
| batch_predictions: torch.Tensor, | ||
| batch_idx: int, | ||
| zarr_output_path: str, | ||
| ): | ||
| """ | ||
| Save state predictions for single batch to zarr dataset. Will append to | ||
| existing dataset for batch_idx > 0. Resulting dataset will contain a | ||
| variable named `state` with coordinates (start_time, | ||
| elapsed_forecast_duration, grid_index, state_feature). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| batch_times : torch.Tensor[int] | ||
| The times for the batch, given as epoch time in nanoseconds. Shape | ||
| is (B, args.pred_steps) where B is the batch size and | ||
| args.pred_steps is the number of prediction steps. | ||
| batch_predictions : torch.Tensor[float] | ||
| The predictions for the batch, given as (B, args.pred_steps, | ||
| num_grid_nodes, d_f) where B is the batch size, args.pred_steps is | ||
| the number of prediction steps, num_grid_nodes is the number of | ||
| grid nodes, and d_f is the number of state features. | ||
| batch_idx : int | ||
| The index of the batch in the current epoch. | ||
| """ | ||
| batch_size = batch_predictions.shape[0] | ||
| # Convert predictions to DataArray using _create_dataarray_from_tensor | ||
| das_pred = [] | ||
| for i in range(len(batch_times)): | ||
| da_pred = self._create_dataarray_from_tensor( | ||
| tensor=batch_predictions[i], | ||
| time=batch_times[i], | ||
| split="test", | ||
| category="state", | ||
| ) | ||
|
|
||
| t0 = da_pred.coords["time"].values[0] | ||
| da_pred.coords["analysis_time"] = t0 | ||
| da_pred.coords["elapsed_forecast_duration"] = da_pred.time - t0 | ||
|
Comment on lines
+416
to
+418
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please have another look over how target times are used here. batch_times being fed in here are not [analysis_time, analysis_time+time_step, analysis_time+2time_step, ...], but rather [analysis_time+time_step, analysis_time+2time_step, ...]. This also because batch_predictions does not include the state at the analysis time. This means that start_time is currently not when the forecast was started.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just adjusting t0 would fix this: joeloskarsson@721ac5e However, you probably need to get the step length from somewhere else than in that commit, as self.step_length is not present on main. |
||
| # set CF-standard names for forecast data in anticipation of | ||
| # input/output to neural-lam eventually all being cf-compliant | ||
| da_pred.analysis_time.attrs[ | ||
| "standard_name" | ||
| ] = "forecast_reference_time" | ||
| da_pred.elapsed_forecast_duration.attrs[ | ||
| "standard_name" | ||
| ] = "forecast_period" | ||
| da_pred = da_pred.swap_dims({"time": "elapsed_forecast_duration"}) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we really want to leave the time coordinate in this DataArray? It looks to me like those values will only be valid for one of the forecasts (and I am not entirely sure which one). |
||
| da_pred.name = "state" | ||
|
|
||
| das_pred.append(da_pred) | ||
|
|
||
| da_pred_batch = xr.concat(das_pred, dim="start_time") | ||
|
|
||
| # Apply chunking along analysis_time so that each batch is saved as a | ||
| # separate chunk | ||
| da_pred_batch = da_pred_batch.chunk({"start_time": batch_size}) | ||
|
|
||
| # copy variables that contain the units and long_name attributes | ||
| # from the source datastore | ||
| # XXX: currently it is hardcoded in mllam-data-prep that these are | ||
| # called {data_category}_feature_{long_name,units}, this should | ||
| # probably be made a format string in mllam-data-prep so that these | ||
| # can be correctly parsed/constructed | ||
| for attr in ["long_name", "units"]: | ||
| var_name = f"state_feature_{attr}" | ||
| da_pred_batch.coords[var_name] = self._datastore._ds[var_name] | ||
|
|
||
| # to handle MultiIndexes (see below) we need to have an xr.Dataset, so | ||
| # we make that here. For now we are only making predictions for "state" | ||
| ds_pred_batch = da_pred_batch.to_dataset(name="state") | ||
|
|
||
| # we need to ensure that if `grid_index` is a MultiIndex, it is | ||
| # serialised so that it can be written to netcdf/zarr. We use | ||
| # `cf_xarray` for this (see | ||
| # https://cf-xarray.readthedocs.io/en/latest/coding.html) since they | ||
| # have implemented a cf-compliant way to safely roundtrip this | ||
| for idx_name in list(ds_pred_batch.indexes): | ||
| idx = ds_pred_batch.indexes[idx_name] | ||
| if isinstance(idx, pd.MultiIndex): | ||
| ds_pred_batch = cfxr.encode_multi_index_as_compress( | ||
| ds_pred_batch, idxnames=[idx_name] | ||
| ) | ||
|
|
||
| if batch_idx == 0: | ||
| logger.info(f"Saving predictions to {zarr_output_path}") | ||
| ds_pred_batch.to_zarr(zarr_output_path, mode="w", consolidated=True) | ||
| else: | ||
| ds_pred_batch.to_zarr( | ||
| zarr_output_path, mode="a", append_dim="start_time" | ||
| ) | ||
|
Comment on lines
+464
to
+470
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we might need to be a bit careful with the time encoding here. In my research branch I get errors like: Basically, the first time step is saved as days, because it might be 00:00UTC, but consecutive steps might be in e.g. hours. |
||
|
|
||
| # pylint: disable-next=unused-argument | ||
| def test_step(self, batch, batch_idx): | ||
| """ | ||
| Run test on single batch | ||
| """ | ||
| # TODO Here batch_times can be used for plotting routines | ||
| prediction, target, pred_std, batch_times = self.common_step(batch) | ||
| # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B, | ||
| # pred_steps, num_grid_nodes, d_f) or (d_f,) | ||
| # prediction: (B, pred_steps, num_grid_nodes, d_f) | ||
| # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) | ||
|
|
||
| time_step_loss = torch.mean( | ||
| self.loss( | ||
|
|
@@ -436,6 +533,14 @@ def test_step(self, batch, batch_idx): | |
| self.spatial_loss_maps.append(log_spatial_losses) | ||
| # (B, N_log, num_grid_nodes) | ||
|
|
||
| if self.args.save_eval_to_zarr_path: | ||
| self._save_predictions_to_zarr( | ||
| batch_times=batch_times, | ||
| batch_predictions=prediction, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These predictions are in the standardized scale. At some point before these are written to disk in the zarr they should be rescaled to the original data scale.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can be done as joeloskarsson@d3f636e |
||
| batch_idx=batch_idx, | ||
| zarr_output_path=self.args.save_eval_to_zarr_path, | ||
| ) | ||
|
|
||
| # Plot example predictions (on rank 0 only) | ||
| if ( | ||
| self.trainer.is_global_zero | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # Third-party | ||
| import pytest | ||
|
|
||
| # First-party | ||
| from neural_lam.train_model import main as train_model_main | ||
| from tests.conftest import init_datastore_example | ||
|
|
||
|
|
||
| @pytest.mark.dependency(depends=["test_training"]) | ||
| def test_inference(request): | ||
| """ | ||
| Run inference on a trained model and save the results to a zarr dataset | ||
| through the command line interface. | ||
|
|
||
| NB: This test will need refactoring once we clean up the command line | ||
| interface | ||
| """ | ||
| datastore = init_datastore_example("mdp") | ||
|
|
||
| # NB: this is brittle and should be refactored when the command line | ||
| # interface is cleaned up so that tests point to neural-lam config files | ||
| # rather than datastore config files | ||
| nl_config_path = datastore.root_path / "config.yaml" | ||
|
|
||
| # fetch the path to the trained model that was saved by the training test | ||
| model_path = request.config.cache.get("model_checkpoint_path", None) | ||
| if model_path is None: | ||
| raise Exception("training test must be run first") | ||
|
|
||
| args = [ | ||
| "--config_path", | ||
| nl_config_path, | ||
| "--model", | ||
| "graph_lam", | ||
| "--eval", | ||
| "test", | ||
| "--load", | ||
| model_path, | ||
| "--hidden_dim", | ||
| "4", | ||
| "--hidden_layers", | ||
| "1", | ||
| "--processor_layers", | ||
| "2", | ||
| "--mesh_aggr", | ||
| "sum", | ||
| "--lr", | ||
| "1.0e-3", | ||
| "--val_steps_to_log", | ||
| "1", | ||
| "3", | ||
| "--num_past_forcing_steps", | ||
| "1", | ||
| "--num_future_forcing_steps", | ||
| "1", | ||
| "--n_example_pred", | ||
| "1", | ||
| "--graph", | ||
| "1level", | ||
| "--save_eval_to_zarr_path", | ||
| "state_test.zarr", | ||
| ] | ||
|
|
||
| train_model_main(args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that _create_dataarray_from_tensor is also used here it seems unreasonable to not properly deal with the hack where a WeatherDataset is instantiatied at each call (
neural-lam/neural_lam/models/ar_model.py
Lines 182 to 185 in 1c281a2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an alternative hack to avoid constantly making new datasets: joeloskarsson@d277b08 however, this is still very much a hack and the TODO remains that we should handle this some proper way.