Skip to content

analyse #

Functionality for analysing the results of a presto run.

Functions:

_add_legend_if_labels #

_add_legend_if_labels(ax: Axes, **kwargs: Any) -> None

Add a legend to the axes only if there are labeled artists.

Source code in presto/analyse.py
def _add_legend_if_labels(ax: Axes, **kwargs: Any) -> None:
    """Add a legend to the axes only if there are labeled artists."""
    handles, labels = ax.get_legend_handles_labels()
    if labels:
        ax.legend(**kwargs)

read_errors #

read_errors(
    paths_by_iter: dict[int, Path],
) -> dict[str, dict[int, NDArray[float64]]]

Read all energy and force data from the HDF5 files.

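Returns: Dictionary with keys: 'energy_reference', 'energy_predicted', 'energy_differences', 'forces_reference', 'forces_predicted', 'forces_differences'. Each of these values is a dict mapping iteration number to numpy array. When at least one file is read, the keys 'n_atoms' and 'n_conformers' additionally hold the corresponding HDF5 file attributes, taken from the last file read.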

Source code in presto/analyse.py
def read_errors(
    paths_by_iter: dict[int, Path],
) -> dict[str, dict[int, npt.NDArray[np.float64]]]:
    """Collect reference/predicted energies and forces from per-iteration HDF5 files.

    Args:
        paths_by_iter: Mapping from iteration number to the HDF5 file written
            for that iteration.

    Returns:
        Dictionary with keys 'energy_reference', 'energy_predicted',
        'energy_differences', 'forces_reference', 'forces_predicted' and
        'forces_differences'; each maps iteration number to the full dataset
        of that name read from the iteration's file.
        When at least one file is read, the keys 'n_atoms' and 'n_conformers'
        also hold the matching file attributes. Each file overwrites them, so
        they come from the last file read. Note these are attribute values,
        not per-iteration dicts, despite the annotated return type.
    """
    dataset_names = (
        "energy_reference",
        "energy_predicted",
        "energy_differences",
        "forces_reference",
        "forces_predicted",
        "forces_differences",
    )
    results: dict[str, dict[int, npt.NDArray[np.float64]]] = {
        name: {} for name in dataset_names
    }

    for iteration, h5_path in paths_by_iter.items():
        with h5py.File(h5_path, "r") as h5_file:
            for name in dataset_names:
                # [:] materialises the dataset before the file is closed.
                results[name][iteration] = h5_file[name][:]
            results["n_atoms"] = h5_file.attrs["n_atoms"]
            results["n_conformers"] = h5_file.attrs["n_conformers"]

    return results

load_force_fields #

load_force_fields(
    paths_by_iter: dict[int, Path],
) -> dict[int, str]

Load the .offxml files from the given paths.

Source code in presto/analyse.py
def load_force_fields(paths_by_iter: dict[int, Path]) -> dict[int, str]:
    """Load every .offxml path into a force field, keeping the iteration keys.

    Args:
        paths_by_iter: Mapping from iteration number to an .offxml file path.

    Returns:
        Mapping from the same iteration numbers to ``ForceField(path)``.
        NOTE(review): the values are ``ForceField`` objects, not ``str`` as the
        return annotation states; the annotation probably needs updating.
    """
    force_fields = {}
    for iteration, offxml_path in paths_by_iter.items():
        force_fields[iteration] = ForceField(offxml_path)
    return force_fields

plot_energy_correlation #

plot_energy_correlation(
    fig: Figure,
    ax: Axes,
    reference: dict[int, NDArray[float64]],
    predicted: dict[int, NDArray[float64]],
) -> None

Plot the correlation between reference and predicted energies, with one scatter series per iteration and a dashed y = x guide line.

Source code in presto/analyse.py
def plot_energy_correlation(
    fig: Figure,
    ax: Axes,
    reference: dict[int, npt.NDArray[np.float64]],
    predicted: dict[int, npt.NDArray[np.float64]],
) -> None:
    """Scatter predicted against reference energies, one series per iteration.

    A dashed red y = x line spanning the combined range of all reference and
    predicted values is drawn as a guide for perfect agreement.

    Args:
        fig: Figure owning ``ax`` (unused; kept for a uniform plotting signature).
        ax: Axes to draw on.
        reference: Mapping from iteration number to reference energies.
        predicted: Mapping from iteration number to predicted energies; must
            contain every key present in ``reference``.

    Raises:
        KeyError: If an iteration in ``reference`` is missing from ``predicted``.
    """
    for i, ref_values in reference.items():
        ax.scatter(ref_values, predicted[i], alpha=0.5, label=f"Iteration {i}")

    # np.concatenate raises on an empty list, and min/max raise on an empty
    # array, so only draw the y = x guide when there is data to span.
    arrays = [*reference.values(), *predicted.values()]
    if arrays:
        all_values = np.concatenate(arrays)
        if all_values.size:
            min_val, max_val = all_values.min(), all_values.max()
            ax.plot(
                [min_val, max_val], [min_val, max_val], color="red", linestyle="--"
            )

    ax.set_xlabel("Reference Energy / kcal mol$^{-1}$")
    ax.set_ylabel("Predicted Energy / kcal mol$^{-1}$")
    ax.set_title("Energy Correlation Plot")

    # Only add a legend when something is labelled (same rule as the other
    # plotting helpers in this module).
    _, labels = ax.get_legend_handles_labels()
    if labels:
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

get_mol_image_with_atom_idxs #

get_mol_image_with_atom_idxs(
    molecule: Molecule, width: int = 300, height: int = 300
) -> Image

Generate a PIL Image of the molecule with atom indices labeled.

Source code in presto/analyse.py
def get_mol_image_with_atom_idxs(
    molecule: Molecule, width: int = 300, height: int = 300
) -> Image.Image:
    """Render ``molecule`` as a PIL image with each atom labelled "<symbol>:<index>".

    Args:
        molecule: Molecule to draw. It is copied first, and the copy's
            conformers are cleared before drawing.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The rendered drawing as a PIL image.
    """
    # Clear conformers on a copy rather than on the caller's molecule.
    # NOTE(review): presumably done so RDKit lays out a 2D depiction instead of
    # projecting 3D coordinates; confirm.
    mol_no_confs = Molecule(molecule)
    mol_no_confs._conformers = None
    rd_mol = mol_no_confs.to_rdkit()

    canvas = Draw.MolDraw2DCairo(width, height)
    draw_opts = canvas.drawOptions()
    for rd_atom in rd_mol.GetAtoms():
        atom_idx = rd_atom.GetIdx()
        # e.g. "C:0", "C:1", "O:2", ...
        draw_opts.atomLabels[atom_idx] = f"{rd_atom.GetSymbol()}:{atom_idx}"

    Draw.rdMolDraw2D.PrepareAndDrawMolecule(canvas, rd_mol)
    canvas.FinishDrawing()

    # The Cairo drawer yields PNG bytes; wrap them so PIL can decode them.
    return Image.open(io.BytesIO(canvas.GetDrawingText()))

plot_force_error_by_atom_idx #

plot_force_error_by_atom_idx(
    fig: Figure,
    ax: Axes,
    errors: dict[int, NDArray[float64]],
    mol: Molecule,
) -> None

Plot a seaborn strip plot of the force-error magnitudes by atom index, with an image of the molecule (atom indices labelled) above the axes.

Source code in presto/analyse.py
def plot_force_error_by_atom_idx(
    fig: Figure,
    ax: Axes,
    errors: dict[int, npt.NDArray[np.float64]],
    mol: Molecule,
) -> None:
    """Plot a seaborn strip plot of the force-error magnitudes by atom index.

    For every iteration, the norm of each row of the force-error array is
    plotted against the atom index it belongs to. An image of the molecule with
    atom indices is drawn in an inset above the axes.

    Args:
        fig: Figure owning ``ax`` (unused here).
        ax: Axes to draw on.
        errors: Mapping from iteration number to force errors. Rows are reduced
            with a norm over axis 1, and row ``k`` is assigned to atom
            ``k % mol.n_atoms``. NOTE(review): this presumes a 2-D array of
            shape ``(n_conformers * n_atoms, 3)`` with the atom index varying
            fastest; confirm against the code writing the HDF5 files.
        mol: The molecule the errors belong to.
    """
    import seaborn as sns
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    for iteration, force_errors in errors.items():
        # Atom index of every row (rows cycle through the atoms).
        atom_indices = np.arange(len(force_errors)) % mol.n_atoms
        df = pd.DataFrame(
            {
                "atom_index": atom_indices,
                "force_error": np.linalg.norm(force_errors, axis=1),
                "iteration": np.ones_like(atom_indices) * iteration,
            }
        )
        sns.stripplot(
            x="atom_index",
            y="force_error",
            data=df,
            ax=ax,
            label=f"Iteration {iteration}",
            alpha=0.4,
        )

    # Molecule picture in an inset above the main plot, so atom indices on the
    # x axis can be matched to atoms.
    mol_image = get_mol_image_with_atom_idxs(mol, width=1800, height=600)
    ax_inset = inset_axes(
        ax,
        width="400%",
        height="120%",
        loc="upper center",
        bbox_to_anchor=(0, 1.15, 1, 0.3),
        bbox_transform=ax.transAxes,
    )
    ax_inset.imshow(mol_image)
    ax_inset.axis("off")

    ax.set_xlabel("Atom Index")
    ax.set_ylabel("Force Error / kcal mol$^{-1}$ Å$^{-1}$")

    # Deduplicate legend entries (the same label can occur more than once);
    # skip the legend entirely when nothing was plotted.
    handles, labels = ax.get_legend_handles_labels()
    if labels:
        by_label = dict(zip(labels, handles, strict=False))
        ax.legend(
            by_label.values(),
            by_label.keys(),
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )

plot_error_statistics #

plot_error_statistics(
    fig: Figure,
    axs: NDArray[Any],
    errors: dict[
        Literal["energy_differences", "forces_differences"],
        dict[int, NDArray[float64]],
    ],
) -> None

Plot the error statistics for the energy and force errors.

Source code in presto/analyse.py
def plot_error_statistics(
    fig: Figure,
    axs: npt.NDArray[Any],
    errors: dict[
        Literal["energy_differences", "forces_differences"],
        dict[int, npt.NDArray[np.float64]],
    ],
) -> None:
    """Plot the error statistics for the energy and force errors.

    Args:
        fig: Figure owning ``axs``; passed through to the plotting helpers.
        axs: Array of at least four axes (any shape; it is flattened). Panels
            are filled as: [0] energy error distribution, [1] force error
            distribution, [2] energy RMSE, [3] force RMSE.
        errors: Mapping with 'energy_differences' and 'forces_differences'
            entries, each mapping iteration number to an error array.
    """
    flat_axs = axs.flatten()
    plot_distributions_of_errors(
        fig, flat_axs[0], errors["energy_differences"], "energy"
    )
    plot_distributions_of_errors(
        fig, flat_axs[1], errors["forces_differences"], "force"
    )

    # Hide the legend in the first (energy distribution) panel. Remove any
    # existing legend instead of calling ``legend()``: that would build a new
    # legend (warning if nothing is labelled) only to hide it again.
    energy_legend = flat_axs[0].get_legend()
    if energy_legend is not None:
        energy_legend.remove()

    # Plot the RMSEs of the errors
    plot_rmse_of_errors(fig, flat_axs[2], errors["energy_differences"], "energy")
    plot_rmse_of_errors(fig, flat_axs[3], errors["forces_differences"], "force")

calculate_dihedrals_for_trajectory #

calculate_dihedrals_for_trajectory(
    pdb_path: Path,
    torsions: dict[
        tuple[int, int], tuple[int, int, int, int]
    ],
) -> dict[tuple[int, int, int, int], NDArray[float64]]

Calculate dihedral angles for all torsions across all frames using MDTraj.

Parameters:

  • pdb_path (Path) –

    Path to the PDB trajectory file.

  • torsions (dict[tuple[int, int], tuple[int, int, int, int]]) –

    Dictionary mapping rotatable bonds to torsion atom indices.

Returns:

  • dict[tuple[int, int, int, int], NDArray[float64]]

    Dictionary mapping torsion atom indices to array of dihedral angles (in degrees) for each frame.

Source code in presto/analyse.py
def calculate_dihedrals_for_trajectory(
    pdb_path: Path,
    torsions: dict[tuple[int, int], tuple[int, int, int, int]],
) -> dict[tuple[int, int, int, int], npt.NDArray[np.float64]]:
    """Calculate dihedral angles for all torsions across all frames using MDTraj.

    All torsions are evaluated with a single ``mdtraj.compute_dihedrals`` call
    instead of one call per torsion.

    Parameters
    ----------
    pdb_path : Path
        Path to the PDB trajectory file.
    torsions : dict[tuple[int, int], tuple[int, int, int, int]]
        Dictionary mapping rotatable bonds to torsion atom indices. Only the
        values (the four atom indices) are used; duplicate values collapse to
        a single entry.

    Returns
    -------
    dict[tuple[int, int, int, int], npt.NDArray[np.float64]]
        Dictionary mapping torsion atom indices to array of dihedral angles
        (in degrees) for each frame. Empty if ``torsions`` is empty. The
        trajectory is still loaded first in that case.
    """
    trajectory = mdtraj.load(str(pdb_path))

    # Unique torsions in first-seen order (dict keys keep insertion order).
    unique_torsions = list(dict.fromkeys(torsions.values()))
    if not unique_torsions:
        # An empty index array would not have the (n, 4) shape MDTraj expects.
        return {}

    # MDTraj takes an (n_dihedrals, 4) index array and returns angles in
    # radians with shape (n_frames, n_dihedrals).
    angles_deg = np.degrees(
        mdtraj.compute_dihedrals(trajectory, np.array(unique_torsions))
    )

    # Column k holds the per-frame angles of unique_torsions[k].
    return {
        torsion_atoms: angles_deg[:, column]
        for column, torsion_atoms in enumerate(unique_torsions)
    }

plot_torsion_dihedrals #

plot_torsion_dihedrals(
    fig: Figure,
    axs: NDArray[Any],
    dihedrals_by_iteration: dict[
        int,
        dict[tuple[int, int, int, int], NDArray[float64]],
    ],
    mol: Molecule,
) -> None

Plot dihedral angles for all rotatable torsions during trajectories.

Each torsion gets its own subplot.

Parameters:

  • fig (Figure) –

    Matplotlib figure.

  • axs (NDArray[Any]) –

    Array of matplotlib axes (one for each torsion).

  • dihedrals_by_iteration (dict) –

    Dictionary mapping iteration to dictionary of torsion dihedrals. Inner dict maps torsion atom indices to array of dihedral angles.

  • mol (Molecule) –

    The molecule being analyzed.

Source code in presto/analyse.py
def plot_torsion_dihedrals(
    fig: Figure,
    axs: npt.NDArray[Any],
    dihedrals_by_iteration: dict[
        int, dict[tuple[int, int, int, int], npt.NDArray[np.float64]]
    ],
    mol: Molecule,
) -> None:
    """Draw one panel per torsion showing its dihedral angle frame by frame.

    Every torsion present in any iteration gets its own axes, in sorted order
    of atom indices. Each panel scatters angle against frame number for every
    iteration containing that torsion, with a matching horizontal histogram
    inset on its right. An image of the molecule with atom indices sits above
    the first panel. Axes beyond the number of torsions are switched off.

    Parameters
    ----------
    fig : Figure
        Matplotlib figure (not used directly).
    axs : npt.NDArray[Any]
        Array of matplotlib axes; must hold at least one axes per torsion.
    dihedrals_by_iteration : dict
        Maps iteration number to a dict from torsion atom indices to the
        per-frame dihedral angles (degrees).
    mol : Molecule
        The molecule being analysed; used for the header image.
    """
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    # Molecule picture with atom indices above the first panel, so torsion
    # titles can be matched to atoms.
    mol_picture = get_mol_image_with_atom_idxs(mol, width=1800, height=600)
    header_ax = axs.flat[0]
    picture_ax = inset_axes(
        header_ax,
        width="400%",
        height="120%",
        loc="upper center",
        bbox_to_anchor=(0, 1.15, 1, 0.3),
        bbox_transform=header_ax.transAxes,
    )
    picture_ax.imshow(mol_picture)
    picture_ax.axis("off")

    torsions_sorted = sorted(
        {
            torsion
            for per_iter in dihedrals_by_iteration.values()
            for torsion in per_iter
        }
    )

    # One colour per iteration. The extra sample means the last iteration
    # never uses the very end (value 1.0) of the colormap.
    palette = plt.colormaps["viridis"](
        np.linspace(0, 1, len(dihedrals_by_iteration) + 1)
    )

    for panel_idx, torsion in enumerate(torsions_sorted):
        panel_ax = axs.flat[panel_idx]
        # (angles, colour) for each iteration that contains this torsion,
        # in iteration order; reused for the histogram inset below.
        plotted: list[tuple[npt.NDArray[np.float64], Any]] = []

        for colour_idx, (iteration, per_iter) in enumerate(
            dihedrals_by_iteration.items()
        ):
            if torsion not in per_iter:
                continue
            angles = per_iter[torsion]
            colour = palette[colour_idx]
            plotted.append((angles, colour))
            panel_ax.scatter(
                np.arange(len(angles)),  # frame numbers
                angles,
                label=f"Iteration {iteration}",
                alpha=0.5,
                color=colour,
                s=10,
            )

        panel_ax.set_xlabel("Frame Number")
        panel_ax.set_ylabel("Dihedral Angle / degrees")
        panel_ax.set_title("Torsion [{}-{}-{}-{}]".format(*torsion))
        for level in (0, 180, -180):
            panel_ax.axhline(
                y=level, color="gray", linestyle="--", alpha=0.3, linewidth=0.5
            )
        panel_ax.set_ylim(-180, 180)

        _add_legend_if_labels(panel_ax, loc="best", fontsize="small")

        if not plotted:
            continue

        # Histogram inset to the right of the panel, colour-matched per iteration.
        hist_ax = inset_axes(
            panel_ax,
            width="20%",
            height="100%",
            loc="center right",
            bbox_to_anchor=(0.15, 0, 1, 1),
            bbox_transform=panel_ax.transAxes,
        )
        for angles, colour in plotted:
            hist_ax.hist(
                angles,
                bins=36,  # 10-degree bins over the 360-degree range
                orientation="horizontal",
                range=(-180, 180),
                alpha=0.5,
                color=colour,
                edgecolor="black",
                linewidth=0.3,
            )
        hist_ax.set_ylim(-180, 180)
        hist_ax.set_xlabel("Count", fontsize=8)
        hist_ax.tick_params(axis="both", labelsize=7)
        hist_ax.yaxis.set_visible(False)

    # Switch off panels that received no torsion.
    for spare_ax in axs.flat[len(torsions_sorted):]:
        spare_ax.axis("off")

analyse_workflow #

analyse_workflow(
    workflow_settings: WorkflowSettings,
) -> None

Analyse the results of a presto workflow.

Source code in presto/analyse.py
def _save_and_close_figure(fig: Figure, path: Path | str) -> None:
    """Save ``fig`` to ``path`` (300 dpi, tight bounding box), then close it.

    The figure is closed even if saving raises, so a failed save does not
    leave an open figure behind.
    """
    try:
        fig.savefig(str(path), dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)


def analyse_workflow(workflow_settings: WorkflowSettings) -> None:
    """Analyse the results of a presto workflow.

    Writes plots into the PLOTS output stage directory:

    * a loss plot built from the training-metric files;
    * for each molecule with scatter data: error-statistics, energy-correlation
      and force-error-by-atom-index plots, plus a torsion dihedral plot when
      trajectory files and rotatable torsions exist;
    * for each molecule: parameter-value and parameter-difference plots across
      all saved force fields.

    Args:
        workflow_settings: Settings of the workflow whose outputs are analysed.
    """
    import logging

    mols = workflow_settings.parameterisation_settings.molecules

    # Suppress matplotlib categorical units warning by setting logger level.
    # NOTE: this changes a process-wide logger and is not restored afterwards.
    logging.getLogger("matplotlib.category").setLevel(logging.ERROR)

    with plt.style.context(PLT_STYLE):
        path_manager = workflow_settings.get_path_manager()
        stage = OutputStage(StageKind.PLOTS)
        path_manager.mk_stage_dir(stage)

        output_paths_by_type = path_manager.get_all_output_paths_by_output_type()
        output_paths_by_type_by_mol = (
            path_manager.get_all_output_paths_by_output_type_by_molecule()
        )

        # Plot the losses
        training_metric_paths = dict(
            enumerate(output_paths_by_type[OutputType.TRAINING_METRICS])
        )
        losses = read_losses(training_metric_paths)
        fig, ax = plt.subplots(figsize=(10, 6))
        plot_loss(fig, ax, losses)
        _save_and_close_figure(
            fig, path_manager.get_output_path(stage, OutputType.LOSS_PLOT)
        )

        # Per-molecule path lookups do not depend on the molecule; do them once.
        scatter_paths_by_mol = output_paths_by_type_by_mol.get(OutputType.SCATTER, {})
        assert isinstance(scatter_paths_by_mol, dict)
        pdb_traj_paths_by_mol = output_paths_by_type_by_mol.get(
            OutputType.PDB_TRAJECTORY, {}
        )

        # Plot for each molecule
        for mol_idx, mol in enumerate(mols):
            if mol_idx not in scatter_paths_by_mol:
                logger.warning(f"No scatter paths found for molecule {mol_idx}")
                continue

            # Convert list of paths to dict indexed by iteration
            errors = read_errors(dict(enumerate(scatter_paths_by_mol[mol_idx])))

            # Plot the errors
            fig, axs = plt.subplots(2, 2, figsize=(13, 12))
            plot_error_statistics(fig, axs, errors)  # type: ignore[arg-type]
            _save_and_close_figure(
                fig,
                get_mol_path(
                    path_manager.get_output_path(stage, OutputType.ERROR_PLOT),
                    mol_idx,
                ),
            )

            # Plot the correlation plots
            fig, ax = plt.subplots(1, 1, figsize=(6.5, 6))
            plot_energy_correlation(
                fig,
                ax,
                errors["energy_reference"],
                errors["energy_predicted"],
            )
            _save_and_close_figure(
                fig,
                get_mol_path(
                    path_manager.get_output_path(stage, OutputType.CORRELATION_PLOT),
                    mol_idx,
                ),
            )

            # Plot the force error by atom index
            fig, ax = plt.subplots(1, 1, figsize=(0.5 * mol.n_atoms, 6))
            plot_force_error_by_atom_idx(fig, ax, errors["forces_differences"], mol)
            _save_and_close_figure(
                fig,
                get_mol_path(
                    path_manager.get_output_path(
                        stage, OutputType.FORCE_ERROR_BY_ATOM_INDEX_PLOT
                    ),
                    mol_idx,
                ),
            )

            # Plot torsion dihedrals only if trajectory files exist for this
            # molecule and it has rotatable torsions.
            if mol_idx not in pdb_traj_paths_by_mol:
                continue
            torsions = get_rot_torsions_by_rot_bond(
                mol,
                include_smarts=DEFAULT_TORSIONS_TO_INCLUDE_SMARTS,
                exclude_smarts=DEFAULT_TORSIONS_TO_EXCLUDE_SMARTS,
            )
            if not torsions:
                continue

            pdb_paths_for_mol = pdb_traj_paths_by_mol[mol_idx]
            if not isinstance(pdb_paths_for_mol, list):
                pdb_paths_for_mol = [pdb_paths_for_mol]

            # Include initial statistics (iteration 0) and training iterations
            dihedrals_by_iteration: dict[
                int, dict[tuple[int, int, int, int], npt.NDArray[np.float64]]
            ] = {}
            for iter_idx, pdb_path in enumerate(pdb_paths_for_mol):
                if not pdb_path.exists():
                    continue
                try:
                    dihedrals_by_iteration[iter_idx] = (
                        calculate_dihedrals_for_trajectory(pdb_path, torsions)
                    )
                except Exception as e:
                    # Best effort: an unreadable trajectory is logged and
                    # skipped so the remaining plots are still produced.
                    logger.warning(f"Failed to read trajectory at {pdb_path}: {e}")

            if not dihedrals_by_iteration:
                continue

            # Grid with at most 2 columns and enough rows for every torsion.
            n_torsions = len(torsions)
            ncols = min(2, n_torsions)
            nrows = (n_torsions + ncols - 1) // ncols  # ceiling division
            fig, axs = plt.subplots(
                nrows, ncols, figsize=(8 * ncols, 5 * nrows), squeeze=False
            )
            plot_torsion_dihedrals(fig, axs, dihedrals_by_iteration, mol)
            fig.tight_layout()
            _save_and_close_figure(
                fig,
                get_mol_path(
                    path_manager.get_output_path(
                        stage, OutputType.TORSION_DIHEDRALS_PLOT
                    ),
                    mol_idx,
                ),
            )

        # Plot the force field changes for each molecule
        offxml_paths = output_paths_by_type.get(OutputType.OFFXML, [])
        assert isinstance(offxml_paths, list)
        force_fields = load_force_fields(dict(enumerate(offxml_paths)))

        for mol_idx, mol in enumerate(mols):
            fig, _ = plot_all_ffs(force_fields, mol, "values")
            _save_and_close_figure(
                fig,
                get_mol_path(
                    path_manager.get_output_path(
                        stage, OutputType.PARAMETER_VALUES_PLOT
                    ),
                    mol_idx,
                ),
            )

            fig, _ = plot_all_ffs(force_fields, mol, "differences")
            _save_and_close_figure(
                fig,
                get_mol_path(
                    path_manager.get_output_path(
                        stage, OutputType.PARAMETER_DIFFERENCES_PLOT
                    ),
                    mol_idx,
                ),
            )