Source code for pyvbmc.vbmc.create_vbmc_animation
import io
from pathlib import Path
import imageio
import matplotlib.pyplot as plt
import numpy as np
from .vbmc import VBMC
[docs]
def create_vbmc_animation(
    vbmc: VBMC,
    path: str,
    as_frames: bool = False,
    suptitle="full",
    **kwargs: dict,
):
    """
    Create and save a gif animation of a VBMC optimization run.

    One frame is drawn per entry of ``vbmc.iteration_history`` plus one final
    frame for ``vbmc.vp``. All frames share the axis limits of the final
    posterior's plot. The final frame is repeated so it is shown longer.

    Parameters
    ----------
    vbmc : VBMC
        The optimized VBMC.
    path : str
        The path where the gif should be saved to.
    as_frames: bool, optional
        If `True`, saves the animation as individual frames. The filename will
        be appended with the frame number. Default `False`.
    suptitle: str, optional
        What kind of supertitle to print. "full" (the default) means include
        the logging action. "iteration" means print only the iteration. "none"
        means do not supertitle the figures.
    **kwargs : dict, optional
        Keyword arguments, passed to ``vp.plot()``.

    Returns
    -------
    fig
        The figure returned by ``vp.plot()`` for the final frame. It is left
        open; the intermediate figures are closed.

    Raises
    ------
    ValueError
        If the ``suptitle`` option is not one of the three supported values.
    """
    # Validate before doing any (expensive) plotting.
    if suptitle not in ("full", "iteration", "none"):
        raise ValueError(f"Unsupported suptitle option {suptitle}.")

    path = Path(path)
    history = vbmc.iteration_history
    n_iter = len(history["iter"])

    # Plot the final posterior once, only to record its axis limits, so that
    # every frame can use the same limits. The figure itself is not kept.
    last_figure = vbmc.vp.plot(gp=vbmc.gp, **kwargs)
    last_figure_axes = np.array(last_figure.axes).reshape(
        (vbmc.vp.D, vbmc.vp.D)
    )
    xlims = [[ax.get_xlim() for ax in row] for row in last_figure_axes]
    ylims = [[ax.get_ylim() for ax in row] for row in last_figure_axes]
    plt.close(last_figure)

    images = []
    # ``gp`` stays None for frame 0 and otherwise carries over from the most
    # recent iteration that assigned it.
    gp = None
    for i in range(n_iter + 1):
        is_final = i == n_iter
        vp = vbmc.vp if is_final else history["vp"][i]

        highlight_data = None
        # NOTE(review): GPs (and new-point highlighting) are only read from
        # the history for 0 < i < len(history["vp"]) - 2; confirm the "- 2".
        if 0 < i < len(history["vp"]) - 2:
            previous_gp = history["gp"][i - 1]
            gp = history["gp"][i]
            highlight_data = _new_point_indices(gp.X, previous_gp.X)

        fig = vp.plot(
            highlight_data=highlight_data, plot_data=True, gp=gp, **kwargs
        )

        title = _frame_suptitle(history, i, n_iter, suptitle)
        if title is not None:
            fig.suptitle(title)

        # Make axis limits the same for all figures and subplots; y-limits
        # are only synchronized below the diagonal.
        axes = np.array(fig.axes).reshape((vp.D, vp.D))
        for r in range(vp.D):
            for c in range(vp.D):
                axes[r, c].set_xlim(xlims[r][c])
                if r > c:
                    axes[r, c].set_ylim(ylims[r][c])
        fig.tight_layout()

        image = _fig_to_img(fig)
        images.append(image)
        if is_final:
            # Append the final frame multiple times to show it longer.
            images.extend([image] * 4)
        else:
            # Intermediate figures are only needed for their pixels.
            plt.close(fig)

    if as_frames:
        for idx, img in enumerate(images):
            # Same result as path.with_stem(...), which needs Python 3.9+.
            frame_path = path.with_name(f"{path.stem}-{idx:03}{path.suffix}")
            imageio.imsave(frame_path, img)
    else:
        imageio.mimsave(path, images, duration=0.5)
    return fig


def _new_point_indices(X, X_previous):
    """
    Return the integer indices of the rows of ``X`` not present in
    ``X_previous``.
    """
    # numpy's set difference only works on 1-D data, so compare rows as
    # tuples. The set is built once rather than once per row.
    previous_rows = set(map(tuple, X_previous))
    # dtype=int keeps the result usable as an index even when it is empty.
    return np.array(
        [idx for idx, row in enumerate(X) if tuple(row) not in previous_rows],
        dtype=int,
    )


def _frame_suptitle(history, i, n_iter, suptitle):
    """
    Return the supertitle for frame ``i`` of ``n_iter + 1`` frames, or
    ``None`` when the frame should have no supertitle.
    """
    if suptitle == "none":
        return None
    if i == n_iter:
        return "PyVBMC final ({} iterations)".format(i - 1)
    if suptitle == "full" and len(history["logging_action"][i]) > 0:
        return "PyVBMC iteration {} ({})".format(
            i, "".join(history["logging_action"][i])
        )
    return "PyVBMC iteration {}".format(i)
def _fig_to_img(fig):
"""
A private helper function to save a figure as an image array with correct
dimensions.
"""
io_buf = io.BytesIO()
fig.savefig(io_buf, format="raw", dpi=fig.dpi)
io_buf.seek(0)
img_arr = np.reshape(
np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1),
)
io_buf.close()
return img_arr