Skip to content

workflow

Implements the overall workflow for fitting a bespoke force field.

Functions:

get_bespoke_force_field

get_bespoke_force_field(
    settings: WorkflowSettings, write_settings: bool = True
) -> ForceField

Fit a bespoke force field. This involves:

  • Parameterising a base force field for each target molecule and generating specific tagged SMARTS parameters
  • Generating training data (e.g. from high-temperature MD simulations)
  • Optimising the parameters of the force field to reproduce the training data
  • Validating the fitted force field against test data

Parameters:

  • settings (WorkflowSettings) –

    The workflow settings to use for fitting the force field.

  • write_settings (bool, default: True) –

    Whether to write the settings to a YAML file in the output directory, by default True.

Returns:

  • ForceField

    The fitted bespoke force field.

Source code in presto/workflow.py
def _save_datasets_per_molecule(datasets_to_save, path_manager, stage) -> None:
    """
    Save one dataset per molecule to the ``ENERGIES_AND_FORCES`` output path of
    ``stage``.

    Parameters
    ----------
    datasets_to_save : iterable
        The datasets to save, ordered by molecule index.
    path_manager
        The path manager returned by ``WorkflowSettings.get_path_manager()``.
    stage : OutputStage
        The output stage whose directory the datasets are written to.
    """
    for mol_idx, dataset in enumerate(datasets_to_save):
        dataset_path_mol = path_manager.get_output_path_for_mol(
            stage, OutputType.ENERGIES_AND_FORCES, mol_idx
        )
        dataset.save_to_disk(str(dataset_path_mol))


def _write_scatter_statistics(
    datasets_test,
    tensor_ff,
    tensor_tops,
    path_manager,
    stage,
    device,
    log_prefix: str,
    ff_label: str,
) -> None:
    """
    Write a per-molecule scatter plot of ``tensor_ff`` against the test data and
    log the energy/force statistics returned by ``write_scatter``.

    Parameters
    ----------
    datasets_test : list
        The test datasets, one per molecule.
    tensor_ff
        The tensor force field to evaluate.
    tensor_tops : list
        The tensor topologies, one per molecule, in the same order as
        ``datasets_test`` (a length mismatch raises ``ValueError``).
    path_manager
        The path manager used to resolve the scatter-plot output paths.
    stage : OutputStage
        The output stage whose directory the scatter plots are written to.
    device
        The device to evaluate on; passed to ``write_scatter`` as a string.
    log_prefix : str
        Text prepended to each log message (e.g. ``"Iteration 3 "``).
    ff_label : str
        Description of the force field used in each log message.
    """
    for mol_idx, (dataset_test, tensor_top) in enumerate(
        zip(datasets_test, tensor_tops, strict=True)
    ):
        scatter_path_mol = path_manager.get_output_path_for_mol(
            stage, OutputType.SCATTER, mol_idx
        )
        energy_mean, energy_sd, forces_mean, forces_sd = write_scatter(
            dataset_test,
            tensor_ff,
            tensor_top,
            str(device),
            str(scatter_path_mol),
        )
        logger.info(
            f"{log_prefix}Molecule {mol_idx} {ff_label} statistics: "
            f"Energy (Mean/SD): {energy_mean:.3e}/{energy_sd:.3e} kcal/mol, "
            f"Forces (Mean/SD): {forces_mean:.3e}/{forces_sd:.3e} kcal/mol/Å"
        )


def get_bespoke_force_field(
    settings: WorkflowSettings, write_settings: bool = True
) -> ForceField:
    """
    Fit a bespoke force field. This involves:

    - Parameterising a base force field for each target molecule and generating
      specific tagged SMARTS parameters
    - Generating training data (e.g. from high-temperature MD simulations)
    - Optimising the parameters of the force field to reproduce the training data
    - Validating the fitted force field against test data

    Outputs (settings, datasets, scatter plots and OFFXML files) are written to
    per-stage directories resolved through ``settings.get_path_manager()``.
    ``analyse_workflow`` is run on the settings once fitting finishes.

    Parameters
    ----------
    settings : WorkflowSettings
        The workflow settings to use for fitting the force field.

    write_settings : bool, optional
        Whether to write the settings to a YAML file in the output directory, by default True.

    Returns
    -------
    ForceField
        The fitted bespoke force field from the final iteration. If
        ``settings.n_iterations`` is 0, this is the initial (untrained) force
        field.
    """
    path_manager = settings.get_path_manager()
    stage = OutputStage(StageKind.BASE)
    path_manager.mk_stage_dir(stage)

    if write_settings:
        settings_output_path = path_manager.get_output_path(
            stage, OutputType.WORKFLOW_SETTINGS
        )
        logger.info(f"Writing workflow settings to {settings_output_path}.")
        # Copy the settings and change the output directory to be "." as we save
        # to the output directory already
        output_settings = copy.deepcopy(settings)
        output_settings.output_dir = pathlib.Path(".")
        output_settings.to_yaml(settings_output_path)

    # Parameterise the base force field for all molecules.
    # NOTE(review): this call receives ``settings.device_type`` whereas every
    # other call in this function receives ``settings.device`` — presumably
    # intentional; confirm.
    off_mols, initial_off_ff, tensor_tops, tensor_ff = parameterise(
        settings.parameterisation_settings, device=settings.device_type
    )

    # Only keep parameter configs for potential types the force field defines.
    pruned_parameter_configs = {
        p_type: p_config
        for p_type, p_config in settings.training_settings.parameter_configs.items()
        if p_type in tensor_ff.potentials_by_type
    }

    trainable = Trainable(
        tensor_ff,
        pruned_parameter_configs,
        settings.training_settings.attribute_configs,
    )

    trainable_parameters = trainable.to_values().to(settings.device)

    # Get a copy of the initial trainable parameters for regularisation
    initial_parameters = trainable_parameters.clone().detach()

    # Generate the test data for all molecules
    stage = OutputStage(StageKind.TESTING)
    path_manager.mk_stage_dir(stage)
    test_sample_fn: SampleFn = _SAMPLING_FNS_REGISTRY[
        type(settings.testing_sampling_settings)
    ]
    logger.info("Generating test data")
    datasets_test = test_sample_fn(
        mols=off_mols,
        off_ff=initial_off_ff,
        device=settings.device,
        settings=settings.testing_sampling_settings,
        output_paths={
            output_type: path_manager.get_output_path(stage, output_type)
            for output_type in settings.testing_sampling_settings.output_types
        },
    )

    # Precomputed datasets are loaded from disk, so there is nothing to save.
    if test_sample_fn is not load_precomputed_dataset:  # type: ignore[comparison-overlap]
        _save_datasets_per_molecule(datasets_test, path_manager, stage)

    # Write out statistics on the initial force field
    stage = OutputStage(StageKind.INITIAL_STATISTICS)
    path_manager.mk_stage_dir(stage)
    _write_scatter_statistics(
        datasets_test,
        tensor_ff,
        tensor_tops,
        path_manager,
        stage,
        settings.device,
        log_prefix="",
        ff_label="initial force field",
    )

    off_ff = convert_to_smirnoff(
        trainable.to_force_field(trainable_parameters), base=initial_off_ff
    )
    off_ff.to_file(str(path_manager.get_output_path(stage, OutputType.OFFXML)))

    train_sample_fn = _SAMPLING_FNS_REGISTRY[type(settings.training_sampling_settings)]

    train_fn = _TRAINING_FNS_REGISTRY[settings.training_settings.optimiser]

    # Train the force field
    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
    )

    datasets_train = None  # Only None for the first iteration

    with progress:
        for iteration in progress.track(
            range(1, settings.n_iterations + 1),  # Start from 1 (0 is untrained)
            description="Iterating the Fit",
        ):
            stage = OutputStage(StageKind.TRAINING, iteration)
            path_manager.mk_stage_dir(stage)

            # Sample new training data with the current SMIRNOFF force field.
            datasets_train_new = train_sample_fn(
                mols=off_mols,
                off_ff=off_ff,
                device=settings.device,
                settings=settings.training_sampling_settings,
                output_paths={
                    output_type: path_manager.get_output_path(stage, output_type)
                    for output_type in settings.training_sampling_settings.output_types
                },
            )

            # Apply outlier filtering if configured
            if settings.outlier_filter_settings is not None:
                logger.info("Applying outlier filtering to training data")
                datasets_train_new = [
                    filter_dataset_outliers(
                        dataset=ds,
                        force_field=tensor_ff,
                        topology=tensor_top,
                        settings=settings.outlier_filter_settings,
                        device=str(settings.device),
                    )
                    for ds, tensor_top in zip(
                        datasets_train_new, tensor_tops, strict=True
                    )
                ]

            # Update training dataset: concatenate if memory is enabled and not the first iteration
            if settings.memory and datasets_train is not None:
                datasets_train = [
                    datasets.combine.concatenate_datasets([ds_old, ds_new])
                    for ds_old, ds_new in zip(
                        datasets_train, datasets_train_new, strict=True
                    )
                ]
            else:
                datasets_train = datasets_train_new

            # Save each dataset (the accumulated one when memory is enabled)
            if train_sample_fn is not load_precomputed_dataset:  # type: ignore[comparison-overlap]
                _save_datasets_per_molecule(datasets_train, path_manager, stage)

            train_output_paths = {
                output_type: path_manager.get_output_path(stage, output_type)
                for output_type in settings.training_settings.output_types
            }

            trainable_parameters, trainable = train_fn(
                trainable_parameters=trainable_parameters,
                initial_parameters=initial_parameters,
                trainable=trainable,
                topologies=tensor_tops,
                datasets=datasets_train,
                datasets_test=datasets_test,
                settings=settings.training_settings,
                output_paths=train_output_paths,
                device=settings.device,
            )

            # Copy the fitted parameters back into ``tensor_ff`` so that outlier
            # filtering and scatter plots use the latest parameters. The fitted
            # force field is built once here rather than once per potential type.
            fitted_tensor_ff = trainable.to_force_field(trainable_parameters)
            for potential_type in trainable._param_types:
                tensor_ff.potentials_by_type[potential_type].parameters = copy.copy(
                    fitted_tensor_ff.potentials_by_type[potential_type].parameters
                )

            # Build a separate force field for conversion (as before) so that
            # ``tensor_ff`` never shares parameter objects with the converter input.
            off_ff = convert_to_smirnoff(
                trainable.to_force_field(trainable_parameters), base=initial_off_ff
            )
            off_ff.to_file(str(path_manager.get_output_path(stage, OutputType.OFFXML)))

            # Write scatter plots for each molecule
            _write_scatter_statistics(
                datasets_test,
                tensor_ff,
                tensor_tops,
                path_manager,
                stage,
                settings.device,
                log_prefix=f"Iteration {iteration} ",
                ff_label="force field",
            )

    # Plot
    analyse_workflow(settings)

    return off_ff