Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,9 @@ def _checkpoint(
)
system.set_property("lambda", lam)

# Delete all frames from the system.
system.delete_all_frames()

# Stream the final system to file.
_sr.stream.save(system, self._filenames[index]["checkpoint"])

Expand Down Expand Up @@ -1796,6 +1799,9 @@ def _checkpoint(
)
system.set_property("lambda", lam)

# Delete all frames from the system.
system.delete_all_frames()

# Stream the checkpoint to file.
_sr.stream.save(system, self._filenames[index]["checkpoint"])

Expand Down
29 changes: 19 additions & 10 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def _create_dynamics(
else:
mols = system

# Delete an existing trajectory frames.
mols.delete_all_frames()

# Overload the device and lambda value.
dynamics_kwargs["device"] = device
dynamics_kwargs["lambda_value"] = lam
Expand Down Expand Up @@ -721,11 +724,17 @@ def __init__(self, system, config):
output_directory=self._config.output_directory,
)
else:
_logger.debug("Restarting from file")

# Check to see if the simulation is already complete.
time = self._system[0].time()
if time > self._config.runtime - self._config.timestep:
_logger.success(f"Simulation already complete. Exiting.")
_logger.success("Simulation already complete. Exiting.")
_sys.exit(0)
else:
_logger.info(
f"Restarting at time {time}, time remaining = {self._config.runtime - time}"
)

try:
with open(self._repex_state, "rb") as f:
Expand Down Expand Up @@ -827,28 +836,28 @@ def run(self):
else:
cycles = int(ceil(cycles))

if self._config.checkpoint_frequency.value() > 0.0:
# Store the current checkpoint frequency.
checkpoint_frequency = self._config.checkpoint_frequency

if checkpoint_frequency.value() > 0.0:
# Calculate the number of blocks and the remainder time.
frac = (self._config.runtime / self._config.checkpoint_frequency).value()
frac = (self._config.runtime / checkpoint_frequency).value()

# Handle the case where the runtime is less than the checkpoint frequency.
if frac < 1.0:
frac = 1.0
self._config.checkpoint_frequency = str(self._config.runtime)
checkpoint_frequency = self._config.runtime

num_blocks = int(frac)
rem = round(frac - num_blocks, 12)

# Work out the number of repex cycles per block.
frac = (
self._config.checkpoint_frequency.value()
/ self._config.energy_frequency.value()
)
frac = (checkpoint_frequency / self._config.energy_frequency).value()

# Handle the case where the checkpoint frequency is less than the energy frequency.
if frac < 1.0:
frac = 1.0
self._config.checkpoint_frequency = str(self._config.energy_frequency)
checkpoint_frequency = self._config.energy_frequency

# Store the number of repex cycles per block.
cycles_per_checkpoint = int(frac)
Expand Down Expand Up @@ -1035,7 +1044,7 @@ def run(self):
repeat(num_blocks + int(rem > 0)),
repeat(i == cycles - 1),
):
if not result:
if error:
_logger.error(
f"Checkpoint failed for {_lam_sym} = "
f"{self._lambda_values[index]:.5f}: {error}"
Expand Down
19 changes: 13 additions & 6 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ def run_window(self, index):
else:
system = self._system.clone()

# Delete an existing trajectory frames.
system.delete_all_frames()

# GPU platform.
if self._is_gpu:
# Get a GPU from the pool.
Expand Down Expand Up @@ -644,22 +647,26 @@ def generate_lam_vals(lambda_base, increment=0.001):
else:
num_energy_neighbours = None

# Store the current checkpoint frequency.
checkpoint_frequency = self._config.checkpoint_frequency

# Store the checkpoint time in nanoseconds.
checkpoint_interval = self._config.checkpoint_frequency.to("ns")
checkpoint_interval = checkpoint_frequency.to("ns")

# Store the start time.
start = _timer()

# Run the simulation, checkpointing in blocks.
if self._config.checkpoint_frequency.value() > 0.0:
if checkpoint_frequency.value() > 0.0:

# Calculate the number of blocks and the remainder time.
frac = (time / self._config.checkpoint_frequency).value()
frac = (time / checkpoint_frequency).value()

# Handle the case where the runtime is less than the checkpoint frequency.
if frac < 1.0:
frac = 1.0
self._config.checkpoint_frequency = f"{time} ps"
checkpoint_frequency = _sr.u(f"{time} ps")
checkpoint_interval = checkpoint_frequency.to("ns")

num_blocks = int(frac)
rem = round(frac - num_blocks, 12)
Expand All @@ -684,7 +691,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
next_frame = self._config.frame_frequency

# Loop until we reach the runtime.
while runtime <= self._config.checkpoint_frequency:
while runtime <= checkpoint_frequency:
# Run the dynamics in blocks of the GCMC frequency.
dynamics.run(
self._config.gcmc_frequency,
Expand Down Expand Up @@ -725,7 +732,7 @@ def generate_lam_vals(lambda_base, increment=0.001):

else:
dynamics.run(
self._config.checkpoint_frequency,
checkpoint_frequency,
energy_frequency=self._config.energy_frequency,
frame_frequency=self._config.frame_frequency,
lambda_windows=lambda_array,
Expand Down
Loading