From 1028d58d0ac74e1aad28825582c03fbf34e4bbfa Mon Sep 17 00:00:00 2001 From: VARENNES Robin Date: Tue, 20 Jan 2026 13:52:08 +0100 Subject: [PATCH] Added inline compution of time derivative of real fields. --- src/interfaces/input_reader.py | 7 +-- src/interfaces/output_saver.py | 5 ++- src/simulation/run_simulation.py | 75 +++++++++++++++++++++++++++++--- 3 files changed, 77 insertions(+), 10 deletions(-) diff --git a/src/interfaces/input_reader.py b/src/interfaces/input_reader.py index 204e9a4..6699bdb 100644 --- a/src/interfaces/input_reader.py +++ b/src/interfaces/input_reader.py @@ -280,11 +280,12 @@ def _set_default(self): ("source_gauss_sigma_x" , "source" , [0.]), ("source_gauss_sigma_y" , "source" , [0.]), # Inline operations - ("fft_filter" , "inline_operations", True), + ("fft_filter" , "inline_operations", True), + ("compute_time_derivatives" , "inline_operations", False), # Callbacks - ("check_crash" , "callbacks", True), + ("check_crash" , "callbacks", True), ("save_real" , "callbacks", True), - ("save_fft" , "callbacks", False), + ("save_fft" , "callbacks", False), ] def default(key, category, default_value): diff --git a/src/interfaces/output_saver.py b/src/interfaces/output_saver.py index c1b7c3d..1217519 100644 --- a/src/interfaces/output_saver.py +++ b/src/interfaces/output_saver.py @@ -123,11 +123,12 @@ def _make_output_folder(self, output_folder): Path(output_folder).mkdir(parents=True, exist_ok=True) return Path(output_folder) - def save_output(self, fields, step, t): + def save_output(self, fields, step, t, overwrite=False): with h5.File(self.output_folder / f'fields_{step:05d}.h5', 'a') as f: for field in fields: + if overwrite and (field in f): + del f[field] f.create_dataset(field, data=fields[field]) - # f.create_dataset('time', data=t) f.require_dataset('time', shape=(), dtype='f', data=t) def save_output_data_file(self, erase_files=True): diff --git a/src/simulation/run_simulation.py b/src/simulation/run_simulation.py index 9a922ee..d2a15a1 100644 --- a/src/simulation/run_simulation.py +++ b/src/simulation/run_simulation.py @@ -75,13 +75,19 @@ def __init__(self, params, saver=None): self.Nt_rk4 = params.user["time"]["Nt_rk4"] self.Nt_diag = params.user["time"]["Nt_diag"] self.dt_rk4 = params.user["time"]["dt_rk4"] + self.dt_diag = params.user["time"]["dt_diag"] self.rk4_per_diag = params.user["time"]["rk4_per_diag"] + self.Tsim = params.user["time"]["Tsim"] self.time = 0.0 self.step_rk4_count = 0 self.step_diag_count = 0 self._stop = False + # Retrieve mesh + self.Nx = params.user["grid"]["Nx"] + self.Ny = params.user["grid"]["Ny"] + # Retrieve and initialize the operation occuring during the loop self.user_callbacks = params.user.get("callbacks", {}) self.user_inline_operations = params.user.get("inline_operations", {}) @@ -98,6 +104,7 @@ def __init__(self, params, saver=None): self._init_callbacks() self._init_inline_operations() self._init_inline_display() + self._init_inline_compute_time_derivative() def run(self, fields): """ @@ -106,6 +113,9 @@ def run(self, fields): """ while self.step_rk4_count < self.Nt_rk4 and not self._stop: + if self.user_inline_operations.get('compute_time_derivatives', False): + self.inline_compute_time_derivative(fields) + if self.step_rk4_count % self.rk4_per_diag == 0: # Trigger callbacks (e.g. saving, crash checking) at diagnostic time step for cb in self.callbacks: @@ -121,26 +131,81 @@ def run(self, fields): self.step_diag_count += 1 fields = step(fields, self.pde, self.dt_rk4, t=self.time) + self.time += self.dt_rk4 self.step_rk4_count += 1 + return fields + def _init_inline_compute_time_derivative(self): + if not self.user_inline_operations.get('compute_time_derivatives', False): + return + self.logger.info("--> Enabling inline computation of time derivatives...") + + # Efficient buffers: store 9 last time steps to estimate time derivative using central finite differences + from collections import deque + self.buffer_time = deque(maxlen=9) + self.buffer_fields = deque(maxlen=9) + + # Coefficients for 8th order central finite differences + coef = [1./280., -4./105., 1./5., -4./5., 0., 4./5., -1./5., 4./105., -1./280.] + self.coef = [x/self.dt_rk4 for x in coef] # Apply the division by the time step h + + return + + def inline_compute_time_derivative(self, fields): + # Central FD cannot be computed for firsts and lasts time steps, + # So until custom stencils are implemented, we set the derivative to 0 + no_cfd_lowercond = self.time < 9*self.dt_rk4 + no_cfd_uppercond = self.time > self.Tsim - 9*self.dt_rk4 + is_diagstep = self.step_rk4_count % self.rk4_per_diag == 0 + + if is_diagstep and (no_cfd_lowercond or no_cfd_uppercond): + real_time_derivative_dict = {'dt_'+k.strip('_fft'): jnp.zeros((self.Ny, self.Nx)) for k in fields.keys()} + self.saver.save_output(real_time_derivative_dict, + step=self.step_diag_count, + t=self.time) + + self.buffer_time.append(self.time) + self.buffer_fields.append(fields) + + cfd_computable = len(self.buffer_time) == 9 + central_buffer_element_is_diagstep = (self.step_rk4_count - 4)%self.rk4_per_diag == 0 + + if cfd_computable and central_buffer_element_is_diagstep: # If the 4rth element of the buffer is an integer, i.e. a time where the derivative should be outputed, then we can compute the time derivative + dt_real_fields_evol = jax.tree_util.tree_map( + lambda *args: sum(a * jnp.fft.ifft2(b).real for a, b in zip(self.coef, args)), + *self.buffer_fields + ) + # Amend key name + for k in list(dt_real_fields_evol.keys()): + new_key = "dt_"+k + new_key = new_key.strip('_fft') + dt_real_fields_evol[new_key] = dt_real_fields_evol.pop(k) + + step_for_output = int((self.step_rk4_count - 4)/self.dt_diag) + self.saver.save_output(dt_real_fields_evol, + step=step_for_output, + t=self.time - 4*self.dt_rk4, + overwrite=True) + return + def _init_callbacks(self): """Read config['callbacks'] and add any desired methods to self.callbacks.""" # Enable inline HDF5 saving of fields in real space if self._save_real: - self.logger.info("Enabling inline HDF5 saving of fields in real space...") + self.logger.info("--> Enabling inline HDF5 saving of fields in real space...") self.callbacks.append(self._save_real_data_callback) # Enable inline HDF5 saving of fields in fourier space if self._save_fft: - self.logger.info("Enabling inline HDF5 saving of fields in fourier space...") + self.logger.info("--> Enabling inline HDF5 saving of fields in fourier space...") self.callbacks.append(self._save_fft_data_callback) # Enable crash checking if self.user_callbacks.get('check_crash', False): - self.logger.info("Enabling crash checking...") + self.logger.info("--> Enabling crash checking...") self.callbacks.append(self._check_crash_callback) def _init_inline_operations(self): @@ -148,11 +213,11 @@ def _init_inline_operations(self): # Enable numerical noise accumulation filter for HW advection if self.eq in ["HW", "mHW", "BHW"]: - self.logger.info("Enabling numerical noise accumulation filter for HW advection...") + self.logger.info("--> Enabling numerical noise accumulation filter for HW advection...") self.inline_operations.append(self._numerical_noise_accumulation_filter) # Enable 2/3 de-aliasing rule if self.user_inline_operations.get('fft_filter', False): - self.logger.info("Enabling 2/3 de-aliasing rule...") + self.logger.info("--> Enabling 2/3 de-aliasing rule...") self.inline_operations.append(self._apply_fft_mask) def _init_inline_display(self):