Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/interfaces/input_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,12 @@ def _set_default(self):
("source_gauss_sigma_x" , "source" , [0.]),
("source_gauss_sigma_y" , "source" , [0.]),
# Inline operations
("fft_filter" , "inline_operations", True),
("fft_filter" , "inline_operations", True),
("compute_time_derivatives" , "inline_operations", False),
# Callbacks
("check_crash" , "callbacks", True),
("check_crash" , "callbacks", True),
("save_real" , "callbacks", True),
("save_fft" , "callbacks", False),
("save_fft" , "callbacks", False),
]

def default(key, category, default_value):
Expand Down
5 changes: 3 additions & 2 deletions src/interfaces/output_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,12 @@ def _make_output_folder(self, output_folder):
Path(output_folder).mkdir(parents=True, exist_ok=True)
return Path(output_folder)

def save_output(self, fields, step, t):
def save_output(self, fields, step, t, overwrite=False):
with h5.File(self.output_folder / f'fields_{step:05d}.h5', 'a') as f:
for field in fields:
if overwrite and (field in f):
del f[field]
f.create_dataset(field, data=fields[field])
# f.create_dataset('time', data=t)
f.require_dataset('time', shape=(), dtype='f', data=t)

def save_output_data_file(self, erase_files=True):
Expand Down
75 changes: 70 additions & 5 deletions src/simulation/run_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,19 @@ def __init__(self, params, saver=None):
self.Nt_rk4 = params.user["time"]["Nt_rk4"]
self.Nt_diag = params.user["time"]["Nt_diag"]
self.dt_rk4 = params.user["time"]["dt_rk4"]
self.dt_diag = params.user["time"]["dt_diag"]
self.rk4_per_diag = params.user["time"]["rk4_per_diag"]
self.Tsim = params.user["time"]["Tsim"]

self.time = 0.0
self.step_rk4_count = 0
self.step_diag_count = 0
self._stop = False

# Retrieve mesh
self.Nx = params.user["grid"]["Nx"]
self.Ny = params.user["grid"]["Ny"]

# Retrieve and initialize the operation occuring during the loop
self.user_callbacks = params.user.get("callbacks", {})
self.user_inline_operations = params.user.get("inline_operations", {})
Expand All @@ -98,6 +104,7 @@ def __init__(self, params, saver=None):
self._init_callbacks()
self._init_inline_operations()
self._init_inline_display()
self._init_inline_compute_time_derivative()

def run(self, fields):
"""
Expand All @@ -106,6 +113,9 @@ def run(self, fields):
"""
while self.step_rk4_count < self.Nt_rk4 and not self._stop:

if self.user_inline_operations.get('compute_time_derivatives', False):
self.inline_compute_time_derivative(fields)

if self.step_rk4_count % self.rk4_per_diag == 0:
# Trigger callbacks (e.g. saving, crash checking) at diagnostic time step
for cb in self.callbacks:
Expand All @@ -121,38 +131,93 @@ def run(self, fields):

self.step_diag_count += 1
fields = step(fields, self.pde, self.dt_rk4, t=self.time)

self.time += self.dt_rk4
self.step_rk4_count += 1

return fields

def _init_inline_compute_time_derivative(self):
if not self.user_inline_operations.get('compute_time_derivatives', False):
return
self.logger.info("--> Enabling inline computation of time derivatives...")

# Efficient buffers: store 9 last time steps to estimate time derivative using central finite differences
from collections import deque
self.buffer_time = deque(maxlen=9)
self.buffer_fields = deque(maxlen=9)

# Coefficients for 8th order central finite differences
coef = [1./280., -4./105., 1./5., -4./5., 0., 4./5., -1./5., 4./105., -1./280.]
self.coef = [x/self.dt_rk4 for x in coef] # Apply the division by the time step h

return

def inline_compute_time_derivative(self, fields):
# Central FD cannot be computed for firsts and lasts time steps,
# So until custom stencils are implemented, we set the derivative to 0
no_cfd_lowercond = self.time < 9*self.dt_rk4
no_cfd_uppercond = self.time > self.Tsim - 9*self.dt_rk4
is_diagstep = self.step_rk4_count % self.rk4_per_diag == 0

if is_diagstep and (no_cfd_lowercond or no_cfd_uppercond):
real_time_derivative_dict = {'dt_'+k.strip('_fft'): jnp.zeros((self.Ny, self.Nx)) for k in fields.keys()}
self.saver.save_output(real_time_derivative_dict,
step=self.step_diag_count,
t=self.time)

self.buffer_time.append(self.time)
self.buffer_fields.append(fields)

cfd_computable = len(self.buffer_time) == 9
central_buffer_element_is_diagstep = (self.step_rk4_count - 4)%self.rk4_per_diag == 0

if cfd_computable and central_buffer_element_is_diagstep: # If the 4rth element of the buffer is an integer, i.e. a time where the derivative should be outputed, then we can compute the time derivative
dt_real_fields_evol = jax.tree_util.tree_map(
lambda *args: sum(a * jnp.fft.ifft2(b).real for a, b in zip(self.coef, args)),
*self.buffer_fields
)
# Amend key name
for k in list(dt_real_fields_evol.keys()):
new_key = "dt_"+k
new_key = new_key.strip('_fft')
dt_real_fields_evol[new_key] = dt_real_fields_evol.pop(k)

step_for_output = int((self.step_rk4_count - 4)/self.dt_diag)
self.saver.save_output(dt_real_fields_evol,
step=step_for_output,
t=self.time - 4*self.dt_rk4,
overwrite=True)
return

def _init_callbacks(self):
"""Read config['callbacks'] and add any desired methods to self.callbacks."""

# Enable inline HDF5 saving of fields in real space
if self._save_real:
self.logger.info("Enabling inline HDF5 saving of fields in real space...")
self.logger.info("--> Enabling inline HDF5 saving of fields in real space...")
self.callbacks.append(self._save_real_data_callback)

# Enable inline HDF5 saving of fields in fourier space
if self._save_fft:
self.logger.info("Enabling inline HDF5 saving of fields in fourier space...")
self.logger.info("--> Enabling inline HDF5 saving of fields in fourier space...")
self.callbacks.append(self._save_fft_data_callback)

# Enable crash checking
if self.user_callbacks.get('check_crash', False):
self.logger.info("Enabling crash checking...")
self.logger.info("--> Enabling crash checking...")
self.callbacks.append(self._check_crash_callback)

def _init_inline_operations(self):
"""Read config['inline_operation'] and add any desired methods to self.inline_operation."""

# Enable numerical noise accumulation filter for HW advection
if self.eq in ["HW", "mHW", "BHW"]:
self.logger.info("Enabling numerical noise accumulation filter for HW advection...")
self.logger.info("--> Enabling numerical noise accumulation filter for HW advection...")
self.inline_operations.append(self._numerical_noise_accumulation_filter)
# Enable 2/3 de-aliasing rule
if self.user_inline_operations.get('fft_filter', False):
self.logger.info("Enabling 2/3 de-aliasing rule...")
self.logger.info("--> Enabling 2/3 de-aliasing rule...")
self.inline_operations.append(self._apply_fft_mask)

def _init_inline_display(self):
Expand Down