Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
905c2a2
feat(era5): default to hourly temporal resolution
fjsuarez Mar 16, 2026
34e40de
feat(swopp3): enforce operational weather constraints in CMA-ES
fjsuarez Mar 16, 2026
b55d85a
feat(land): smooth distance-to-land penalty via EDT precomputation
fjsuarez Mar 16, 2026
25934c2
fix: address PR #62 review comments
fjsuarez Mar 16, 2026
235bd70
fix(land): add _edt to Land bypass constructors in era5 loader
fjsuarez Mar 16, 2026
4da1f20
perf(swopp3): reuse windfield closure as vectorfield
fjsuarez Mar 16, 2026
2060916
fix(swopp3): free corridor data on GPU when switching corridors
fjsuarez Mar 16, 2026
fd9627b
fix(cost): use per-segment speed in evaluate_route_energy
fjsuarez Mar 16, 2026
7b3fe38
feat(swopp3): split SLURM jobs by corridor with hourly n-points
fjsuarez Mar 17, 2026
dea4c7a
fix(swopp3): repeat --cases flag for typer list option
fjsuarez Mar 17, 2026
1bdce69
fix(cmaes): add geographic bounds and strengthen land penalties
fjsuarez Mar 17, 2026
97f3db1
Potential fix for pull request finding
daniprec Mar 17, 2026
633f5a7
Potential fix for pull request finding
daniprec Mar 17, 2026
f46cea6
Potential fix for pull request finding
daniprec Mar 17, 2026
ed541db
Potential fix for pull request finding
daniprec Mar 17, 2026
a0ec7a0
fix(land): import distance_transform_edt for distance calculations
daniprec Mar 17, 2026
320c32f
fix(swopp3): reduce K=6 and remove bounds to prevent loops
fjsuarez Mar 17, 2026
00138fb
perf(swopp3): stage repo onto /scratch SSD for faster I/O
fjsuarez Mar 17, 2026
3cff22c
refactor(slurm): single staging job for /scratch rsync + submission w…
fjsuarez Mar 17, 2026
81e66e6
refactor(slurm): git pull instead of rsync for code updates on /scratch
fjsuarez Mar 17, 2026
e0157da
fix(slurm): use cp -rv to copy tracks subdirectory back to /home
fjsuarez Mar 17, 2026
658c3d8
feat(swopp3): add n-points resample comparison script
fjsuarez Mar 17, 2026
e7c5981
fix(slurm): handle existing /scratch dir without .git in staging
fjsuarez Mar 17, 2026
886266d
fix(npoints): fix timezone-aware datetime subtraction
fjsuarez Mar 17, 2026
c8c8998
fix(scripts): lint fixes and executable permissions for npoints scripts
fjsuarez Mar 17, 2026
faaa9ea
fix(scorer): resampling, trapezoidal rule, violations, interpolation …
fjsuarez Mar 18, 2026
1ea68a9
feat(cmaes): configurable weather penalty type and control points
fjsuarez Mar 18, 2026
b822abf
fix(slurm): run CPU job from /scratch, fix lint in test_resample_track
fjsuarez Mar 18, 2026
0fdd5b3
fix(validate): add magnitude check to strategy pair validation
fjsuarez Mar 18, 2026
9c4967c
feat: parameter sweep scripts for Stage A/B tuning
fjsuarez Mar 19, 2026
4354652
fix: revert ERA5 interpolation to order=1 (JAX limit)
fjsuarez Mar 19, 2026
f9742a5
feat: Atlantic parameter sweep + Pacific sweep results summary
fjsuarez Mar 19, 2026
78a9853
feat: cubic B-spline interpolation in JAX + no-constraint run script
fjsuarez Mar 19, 2026
3b53df0
feat: add --temporal-stride to reduce ERA5 JIT constant size
fjsuarez Mar 19, 2026
c459cc6
fix: use CPU mode for no-constraint run (avoids stride=3 artifacts)
fjsuarez Mar 19, 2026
2915abe
add: comparison plots for no-constraint experiment
fjsuarez Mar 19, 2026
a989627
feat(slurm): add SLURM script for 0526 fixed-penalty pipeline run
daniprec Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions routetools/cmaes.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def _cma_evolution_strategy(
]
| None = None,
penalty: float = 1e10,
land_distance_weight: float = 0.0,
land_distance_epsilon: float = 1.0,
weather_penalty_weight: float = 0.0,
tws_limit: float = 20.0,
hs_limit: float = 7.0,
Expand Down Expand Up @@ -379,6 +381,12 @@ def _cma_evolution_strategy(
# toward fewer land points.
cost = jnp.where(has_land, penalty + land_count, cost)

# Smooth distance-to-land penalty via EDT
if land is not None and land_distance_weight > 0:
cost += land.distance_penalty(
curve, weight=land_distance_weight, epsilon=land_distance_epsilon
)

# Weather constraint penalization
if weather_penalty_weight > 0 and (
windfield is not None or wavefield is not None
Expand All @@ -390,6 +398,8 @@ def _cma_evolution_strategy(
tws_limit=tws_limit,
hs_limit=hs_limit,
penalty=weather_penalty_weight,
travel_time=travel_time,
time_offset=time_offset,
)

# Replace the worst solutions with the best found so far
Expand Down Expand Up @@ -435,6 +445,8 @@ def optimize(
]
| None = None,
penalty: float = 1e10,
land_distance_weight: float = 0.0,
land_distance_epsilon: float = 1.0,
weather_penalty_weight: float = 0.0,
tws_limit: float = 20.0,
hs_limit: float = 7.0,
Expand Down Expand Up @@ -488,6 +500,11 @@ def optimize(
penalty : float, optional
Large penalty applied to routes that intersect land (death-penalty
scheme), by default 1e10
land_distance_weight : float, optional
Weight for the smooth distance-to-land penalty via EDT.
Set to 0 (default) to disable.
land_distance_epsilon : float, optional
Regularisation constant for the EDT penalty (default 1.0).
weather_penalty_weight : float, optional
Penalty weight for weather constraint violations (TWS, Hs).
Set to 0 (default) to disable weather penalties.
Expand Down Expand Up @@ -594,6 +611,8 @@ def optimize(
wavefield=wavefield,
windfield=windfield,
penalty=penalty,
land_distance_weight=land_distance_weight,
land_distance_epsilon=land_distance_epsilon,
weather_penalty_weight=weather_penalty_weight,
tws_limit=tws_limit,
hs_limit=hs_limit,
Expand Down Expand Up @@ -667,6 +686,8 @@ def optimize(
tws_limit=tws_limit,
hs_limit=hs_limit,
penalty=weather_penalty_weight,
travel_time=travel_time,
time_offset=time_offset,
).item()
if cost_initial < cost_best:
warnings.warn(
Expand Down Expand Up @@ -706,6 +727,8 @@ def optimize_with_increasing_penalization(
penalty_init: float = 0,
penalty_increment: float = 10,
maxiter: int = 10,
land_distance_weight: float = 0.0,
land_distance_epsilon: float = 1.0,
weather_penalty_weight: float = 0.0,
tws_limit: float = 20.0,
hs_limit: float = 7.0,
Expand Down Expand Up @@ -823,6 +846,8 @@ def optimize_with_increasing_penalization(
wavefield=wavefield,
windfield=windfield,
penalty=penalty,
land_distance_weight=land_distance_weight,
land_distance_epsilon=land_distance_epsilon,
weather_penalty_weight=weather_penalty_weight,
tws_limit=tws_limit,
hs_limit=hs_limit,
Expand Down
2 changes: 1 addition & 1 deletion routetools/era5/download_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _select_corridor(
corridor: str,
year: str = "2024",
months: list[int] | None = None,
time_step: int = 6,
time_step: int = 1,
) -> xr.Dataset:
"""Subset dataset to a corridor, year, and temporal step.

Expand Down
56 changes: 56 additions & 0 deletions routetools/land.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax import jit
from jax.scipy.ndimage import map_coordinates
from perlin_numpy import generate_perlin_noise_2d as pn2d
from scipy.ndimage import distance_transform_edt
Comment thread
daniprec marked this conversation as resolved.
Outdated

from routetools._cost.haversine import haversine_meters_components

Expand Down Expand Up @@ -117,6 +118,14 @@ def __init__(
self._map_mode = map_mode
self._map_order = map_order

# Precompute EDT: distance (in grid cells) from each water cell
# to the nearest land cell. Land cells get distance 0.
binary_land = np.asarray(self._array > self.water_level)
# distance_transform_edt measures distance from 0-cells to nearest
# 1-cell, so we pass the *inverted* mask (water=0 → measure distance).
edt = distance_transform_edt(~binary_land)
self._edt = jnp.asarray(edt, dtype=jnp.float32)

@property
def array(self) -> jnp.ndarray:
"""Return a boolean array indicating land presence."""
Expand Down Expand Up @@ -279,6 +288,53 @@ def penalization(self, curve: jnp.ndarray, penalty: float) -> jnp.ndarray:
# Return the sum of the number of land intersections times the penalty
return jnp.sum(is_land, axis=1) * penalty

@partial(jit, static_argnums=(0,))
def distance_penalty(
self,
curve: jnp.ndarray,
weight: float = 1.0,
epsilon: float = 1.0,
) -> jnp.ndarray:
"""Smooth repulsive penalty based on precomputed EDT.

Samples the Euclidean Distance Transform at each waypoint and
returns ``weight * sum(1 / (edt + epsilon))`` per route. Points on
or very near land produce large penalties; points far from land
contribute negligibly. Uses ``map_coordinates`` for O(1) per-point
lookups and is fully JIT-compatible.

Parameters
----------
curve : jnp.ndarray
Batch of curves, shape ``(W, L, 2)`` with ``(lon, lat)``.
weight : float
Scaling factor for the penalty (default 1.0).
epsilon : float
Regularisation constant to avoid division by zero (default 1.0).

Returns
-------
jnp.ndarray
Penalty per route, shape ``(W,)``.
"""
x_coords = curve[..., 0]
y_coords = curve[..., 1]

x_norm = (x_coords - self.xmin) * self.xnorm
y_norm = (y_coords - self.ymin) * self.ynorm

# Sample the EDT field using instance interpolation settings
edt_vals = map_coordinates(
self._edt,
[x_norm, y_norm],
order=self._map_order,
mode=self._map_mode,
)

Comment thread
daniprec marked this conversation as resolved.
# Inverse-distance penalty: closer to land → larger cost
point_penalty = 1.0 / (edt_vals + epsilon)
return weight * jnp.sum(point_penalty, axis=-1)

def distance_to_land(
self, curve: jnp.ndarray, haversine: bool = False
) -> jnp.ndarray:
Expand Down
8 changes: 8 additions & 0 deletions routetools/swopp3_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,14 @@ def _rise_cost(curve_batch: jnp.ndarray) -> jnp.ndarray:
penalty=1000,
land_margin=2,
verbose=False,
# Operational weather constraints (SWOPP3: TWS < 20 m/s, Hs < 7 m)
windfield=windfield,
wavefield=wavefield,
weather_penalty_weight=100.0,
travel_time=travel_time,
time_offset=departure_offset_h,
# Smooth distance-to-land repulsion via EDT
land_distance_weight=10.0,
)
defaults.update(cmaes_kwargs)

Expand Down
25 changes: 21 additions & 4 deletions routetools/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,21 @@ def _segment_midpoints(
Shape ``(B, L, 2)`` with ``(lon, lat)``.
travel_stw : float, optional
Constant speed through water (m/s). Used to estimate elapsed time
per segment as ``distance / travel_stw``.
per segment as ``distance / travel_stw``; ``t_mid`` will be in
seconds when ``spherical_correction`` is ``True``.
travel_time : float, optional
Total travel time (seconds). Time is distributed proportionally
by segment distance.
Total travel time. Time is distributed proportionally by segment
distance. The units of ``travel_time`` carry through to ``t_mid``
(e.g. pass hours to get hours).
spherical_correction : bool
If ``True`` (default), use haversine distances (metres).

Returns
-------
mid_lon, mid_lat, t_mid : jnp.ndarray
Each of shape ``(B, L-1)``. ``t_mid`` is the estimated elapsed
time (in seconds) at the midpoint of each segment.
time at the midpoint of each segment, in the same units as
``travel_time`` or seconds when ``travel_stw`` is used.
"""
mid_lon = (curve[:, :-1, 0] + curve[:, 1:, 0]) / 2
mid_lat = (curve[:, :-1, 1] + curve[:, 1:, 1]) / 2
Expand Down Expand Up @@ -261,6 +264,7 @@ def weather_penalty(
travel_stw: float | None = None,
travel_time: float | None = None,
spherical_correction: bool = True,
time_offset: float = 0.0,
) -> jnp.ndarray:
"""Compute a hard penalty for weather constraint violations.

Expand Down Expand Up @@ -288,6 +292,11 @@ def weather_penalty(
Total travel time (seconds); distributed proportionally by distance.
spherical_correction : bool
Use haversine distances (default ``True``).
time_offset : float
Offset added to ``t_mid`` before querying field closures. Must
be in the same units as ``travel_time`` (or seconds when using
``travel_stw``). Typically the departure offset in hours when
``travel_time`` is also in hours.
Comment thread
daniprec marked this conversation as resolved.
Outdated

Returns
-------
Expand All @@ -300,6 +309,7 @@ def weather_penalty(
travel_time=travel_time,
spherical_correction=spherical_correction,
)
t_mid = t_mid + time_offset
Comment thread
daniprec marked this conversation as resolved.

violations = jnp.zeros(curve.shape[0])

Expand Down Expand Up @@ -337,6 +347,7 @@ def weather_penalty_smooth(
travel_stw: float | None = None,
travel_time: float | None = None,
spherical_correction: bool = True,
time_offset: float = 0.0,
) -> jnp.ndarray:
"""Compute a smooth (differentiable) penalty for weather violations.

Expand Down Expand Up @@ -369,6 +380,11 @@ def weather_penalty_smooth(
Total travel time (seconds); distributed proportionally by distance.
spherical_correction : bool
Use haversine distances (default ``True``).
time_offset : float
Offset added to ``t_mid`` before querying field closures. Must
be in the same units as ``travel_time`` (or seconds when using
``travel_stw``). Typically the departure offset in hours when
``travel_time`` is also in hours.
Comment thread
daniprec marked this conversation as resolved.
Outdated

Returns
-------
Expand All @@ -381,6 +397,7 @@ def weather_penalty_smooth(
travel_time=travel_time,
spherical_correction=spherical_correction,
)
t_mid = t_mid + time_offset

total = jnp.zeros(curve.shape[0])

Expand Down
29 changes: 14 additions & 15 deletions scripts/swopp3_slurm.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#!/bin/bash
#SBATCH --job-name=swopp3_0125
#SBATCH --job-name=swopp3_1h
#SBATCH --partition=cpu
#SBATCH --nodes=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=128G
#SBATCH --mem=256G
#SBATCH --time=2-00:00:00
#SBATCH --output=slurm_%j.out
#SBATCH --error=slurm_%j.err

# ── SWOPP3 full run on rust-HPC (0.125° ERA5 data, CPU mode) ──
# ── SWOPP3 full run on rust-HPC (hourly ERA5 data, CPU mode) ──
#
# Submit: sbatch scripts/swopp3_slurm.sh
# Monitor: squeue -u $USER
Expand All @@ -28,9 +28,8 @@ export JAX_PLATFORMS=cpu
export XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }--xla_cpu_multi_thread_eigen=true --xla_force_host_platform_device_count=${SLURM_CPUS_PER_TASK}"

# ── Paths ──
DATA_025="data/era5"
DATA_0125="data/era5_0125"
OUTDIR="output/swopp3_0125_rust"
DATA="data/era5"
OUTDIR="output/swopp3_cpu"

mkdir -p "$OUTDIR"

Expand All @@ -40,16 +39,16 @@ echo "Date: $(date)"
echo "CPUs: ${SLURM_CPUS_PER_TASK}"
echo "Memory: ${SLURM_MEM_PER_NODE}MB"
echo "JAX: CPU mode"
echo "Data: 0.125°"
echo "Data: hourly ERA5"
echo "Output: ${OUTDIR}"
echo "======================================"

# Verify data is present
for f in \
"${DATA_0125}/era5_wind_atlantic_2024.nc" \
"${DATA_0125}/era5_waves_atlantic_2024.nc" \
"${DATA_0125}/era5_wind_pacific_2024.nc" \
"${DATA_0125}/era5_waves_pacific_2024.nc"; do
"${DATA}/era5_wind_atlantic_2024.nc" \
"${DATA}/era5_waves_atlantic_2024.nc" \
"${DATA}/era5_wind_pacific_2024.nc" \
"${DATA}/era5_waves_pacific_2024.nc"; do
if [[ ! -f "$f" ]]; then
echo "ERROR: Missing data file: $f" >&2
exit 1
Expand All @@ -66,10 +65,10 @@ echo "Starting SWOPP3 run at $(date)"
echo ""

python scripts/swopp3_run.py \
--wind-path-atlantic "${DATA_0125}/era5_wind_atlantic_2024.nc" \
--wave-path-atlantic "${DATA_0125}/era5_waves_atlantic_2024.nc" \
--wind-path-pacific "${DATA_0125}/era5_wind_pacific_2024.nc" \
--wave-path-pacific "${DATA_0125}/era5_waves_pacific_2024.nc" \
--wind-path-atlantic "${DATA}/era5_wind_atlantic_2024.nc" \
--wave-path-atlantic "${DATA}/era5_waves_atlantic_2024.nc" \
--wind-path-pacific "${DATA}/era5_wind_pacific_2024.nc" \
--wave-path-pacific "${DATA}/era5_waves_pacific_2024.nc" \
--output-dir "$OUTDIR"

echo ""
Expand Down
Loading