
train #

Apply OpenFF parameters to a molecule, cluster conformers by RMSD, and train.

Classes:

  • TrainingFnArgs – Arguments for training functions.
  • TrainFn – A protocol for training functions.

Functions:

  • train_levenberg_marquardt – Iterate the training process using the Levenberg-Marquardt algorithm.
  • train_adam – Iterate the training process using the Adam optimizer.

_TRAINING_FNS_REGISTRY module-attribute #

_TRAINING_FNS_REGISTRY: dict[OptimiserName, TrainFn] = {}

Registry of training functions for different optimiser names.

TrainingFnArgs #

Bases: TypedDict

Arguments for training functions.

TrainFn #

Bases: Protocol

A protocol for training functions.
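
Together, the registry, the TrainingFnArgs arguments and the TrainFn protocol let an optimiser be selected by name at run time. The sketch below shows one way such a registration pattern typically fits together; the decorator name _register_training_fn matches the one visible in the source listings further down, but everything else here is an illustrative assumption rather than presto's actual implementation.

from typing import Callable, Protocol, TypedDict

import torch

# Hypothetical stand-ins for the real presto/descent types.
OptimiserName = str


class TrainingFnArgs(TypedDict, total=False):
    # Illustrative subset of the keyword arguments shared by training functions.
    settings: object
    device: torch.device


class TrainFn(Protocol):
    # Any callable matching the train_* signatures below satisfies this protocol.
    def __call__(
        self,
        trainable_parameters: torch.Tensor,
        initial_parameters: torch.Tensor,
        *args: object,
        **kwargs: object,
    ) -> tuple[torch.Tensor, object]: ...


_TRAINING_FNS_REGISTRY: dict[OptimiserName, TrainFn] = {}


def _register_training_fn(name: OptimiserName) -> Callable[[TrainFn], TrainFn]:
    """Register a training function under an optimiser name (sketch only)."""

    def decorator(fn: TrainFn) -> TrainFn:
        _TRAINING_FNS_REGISTRY[name] = fn
        return fn

    return decorator

With this pattern, decorating train_adam with @_register_training_fn("adam") makes it retrievable as _TRAINING_FNS_REGISTRY["adam"].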

train_levenberg_marquardt #

train_levenberg_marquardt(
    trainable_parameters: Tensor,
    initial_parameters: Tensor,
    trainable: Trainable,
    topologies: list[TensorTopology],
    datasets: list[Dataset],
    datasets_test: list[Dataset],
    settings: TrainingSettings,
    output_paths: dict[OutputType, PathLike],
    device: device,
) -> tuple[Tensor, Trainable]

Iterate the training process using the Levenberg-Marquardt algorithm.

Returns:

  • tuple[torch.Tensor, descent.train.Trainable]

    The updated parameters and the trainable object.
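
Note that output_paths must contain exactly the keys listed in settings.output_types; the source below raises a ValueError otherwise. A small runnable sketch of that contract, using a hypothetical stand-in for presto's OutputType enum:

from enum import Enum
from pathlib import Path


class OutputType(Enum):
    # Hypothetical stand-in for the members of presto's OutputType enum used here.
    TENSORBOARD = "tensorboard"
    TRAINING_METRICS = "training_metrics"


def check_output_paths(
    output_paths: dict[OutputType, Path], expected: set[OutputType]
) -> None:
    # Mirrors the validation at the top of both training functions: the mapping
    # must contain exactly the expected keys, no more and no fewer.
    if set(output_paths) != expected:
        raise ValueError(f"Output paths must contain exactly the keys {expected}")


check_output_paths(
    {
        OutputType.TENSORBOARD: Path("runs/experiment-0"),
        OutputType.TRAINING_METRICS: Path("runs/experiment-0/metrics.txt"),
    },
    expected={OutputType.TENSORBOARD, OutputType.TRAINING_METRICS},
)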

Source code in presto/train.py
@_register_training_fn("lm")
def train_levenberg_marquardt(
    trainable_parameters: torch.Tensor,
    initial_parameters: torch.Tensor,
    trainable: descent.train.Trainable,
    topologies: list[smee.TensorTopology],
    datasets: list[datasets.Dataset],
    datasets_test: list[datasets.Dataset],
    settings: TrainingSettings,
    output_paths: dict[OutputType, PathLike],
    device: torch.device,
) -> tuple[torch.Tensor, descent.train.Trainable]:
    """
    Iterate the training process using the Levenberg-Marquardt algorithm.

    Parameters
    ----------
        trainable_parameters: torch.Tensor
            The parameters to be optimized.
        initial_parameters: torch.Tensor
            The initial parameters before training.
        trainable: descent.train.Trainable
            The trainable object containing the parameters.
        topologies: list[smee.TensorTopology]
            The topologies of the systems.
        datasets: list[datasets.Dataset]
            The datasets to be used for training.
        datasets_test: list[datasets.Dataset]
            The datasets to be used for testing.
        settings: TrainingSettings
            The settings object containing training parameters.
        output_paths: dict[OutputType, PathLike]
            A mapping of output types to filesystem paths. The following keys are
            expected:
                - OutputType.TENSORBOARD
                - OutputType.TRAINING_METRICS
        device: torch.device
            The device to perform training on.

    Returns
    -------
        tuple[torch.Tensor, descent.train.Trainable]
            The updated parameters and the trainable object.
    """
    # Warn the user that LM needs more testing
    logger.warning(
        "Levenberg-Marquardt optimisation is an experimental feature and may not "
        "work as expected."
    )

    # Make sure we have all the required output paths and no others
    if set(output_paths.keys()) != settings.output_types:
        raise ValueError(
            f"Output paths must contain exactly the keys {settings.output_types}"
        )

    # Run the training with the LM optimiser
    lm_config = descent.optim.LevenbergMarquardtConfig(
        mode="adaptive", n_convergence_criteria=2, max_steps=settings.n_epochs
    )

    closure_fn = get_loss_closure_fn(
        datasets,
        trainable,
        trainable_parameters,
        initial_parameters,
        topologies,
        settings.regularisation_target,
    )

    correct_fn = trainable.clamp

    # Create report function that computes metrics consistently with train_adam
    report_fn = functools.partial(
        report,
        trainable=trainable,
        topologies=topologies,
        datasets_train=datasets,
        datasets_test=datasets_test,
        initial_parameters=initial_parameters,
        regularisation_target=settings.regularisation_target,
        metrics_file=output_paths[OutputType.TRAINING_METRICS],
        experiment_dir=Path(output_paths[OutputType.TENSORBOARD]),
    )

    trainable_parameters = descent.optim.levenberg_marquardt(
        trainable_parameters, lm_config, closure_fn, correct_fn, report_fn
    )
    trainable_parameters.requires_grad_(True)

    return trainable_parameters, trainable
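
For orientation, the core of a Levenberg-Marquardt iteration is a damped Gauss-Newton step: solve (JᵀJ + λ diag(JᵀJ)) δ = -Jᵀr, where J is the Jacobian of the residual vector r and λ interpolates between Gauss-Newton (small λ) and steepest descent (large λ). The sketch below is a generic, self-contained illustration of that update using torch's public API; it is not the descent.optim implementation called above, and all names in it are hypothetical. An adaptive variant additionally adjusts λ from step to step depending on whether the proposed step lowered the loss, which is presumably what the adaptive mode above refers to.

import torch


def lm_step(residuals_fn, x: torch.Tensor, damping: float = 1e-3) -> torch.Tensor:
    """One damped Gauss-Newton (Levenberg-Marquardt) update, for illustration."""
    r = residuals_fn(x)
    # Jacobian of the residual vector with respect to the parameters.
    jac = torch.autograd.functional.jacobian(residuals_fn, x)
    jtj = jac.T @ jac
    # Damping interpolates between Gauss-Newton and steepest descent.
    lhs = jtj + damping * torch.diag(torch.diagonal(jtj))
    rhs = -jac.T @ r
    return x + torch.linalg.solve(lhs, rhs)


# Toy example: fit y ≈ a * exp(b * t) starting from (a, b) = (1, 0).
t = torch.linspace(0.0, 1.0, 20)
y = 2.0 * torch.exp(0.5 * t)
params = torch.tensor([1.0, 0.0])
for _ in range(20):
    params = lm_step(lambda p: p[0] * torch.exp(p[1] * t) - y, params)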

train_adam #

train_adam(
    trainable_parameters: Tensor,
    initial_parameters: Tensor,
    trainable: Trainable,
    topologies: list[TensorTopology],
    datasets: list[Dataset],
    datasets_test: list[Dataset],
    settings: TrainingSettings,
    output_paths: dict[OutputType, PathLike],
    device: device,
) -> tuple[Tensor, Trainable]

Iterate the training process using the Adam optimizer.

Returns:

  • tuple[torch.Tensor, descent.train.Trainable]

    The updated parameters and the trainable object.

Source code in presto/train.py
@_register_training_fn("adam")
def train_adam(
    trainable_parameters: torch.Tensor,
    initial_parameters: torch.Tensor,
    trainable: descent.train.Trainable,
    topologies: list[smee.TensorTopology],
    datasets: list[datasets.Dataset],
    datasets_test: list[datasets.Dataset],
    settings: TrainingSettings,
    output_paths: dict[OutputType, PathLike],
    device: torch.device,
) -> tuple[torch.Tensor, descent.train.Trainable]:
    """
    Iterate the training process using the Adam optimizer.

    Parameters
    ----------
        trainable_parameters: torch.Tensor
            The parameters to be optimized.
        initial_parameters: torch.Tensor
            The initial parameters before training.
        trainable: descent.train.Trainable
            The trainable object containing the parameters.
        topologies: list[smee.TensorTopology]
            The topologies of the systems.
        datasets: list[datasets.Dataset]
            The datasets to be used for training.
        datasets_test: list[datasets.Dataset]
            The datasets to be used for testing.
        settings: TrainingSettings
            The settings object containing training parameters.
        output_paths: dict[OutputType, PathLike]
            A mapping of output types to filesystem paths. The following keys are
            expected:
                - OutputType.TENSORBOARD
                - OutputType.TRAINING_METRICS
        device: torch.device
            The device to perform training on.

    Returns
    -------
        tuple[torch.Tensor, descent.train.Trainable]
            The updated parameters and the trainable object.
    """
    # Make sure we have all the required output paths and no others
    if set(output_paths.keys()) != settings.output_types:
        raise ValueError(
            f"Output paths must contain exactly the keys {settings.output_types}"
        )

    # run the ML training

    with open(output_paths[OutputType.TRAINING_METRICS], "w") as metrics_file:
        with open_writer(Path(output_paths[OutputType.TENSORBOARD])) as writer:
            optimizer = torch.optim.Adam(
                [trainable_parameters], lr=settings.learning_rate, amsgrad=True
            )
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=settings.learning_rate_decay
            )
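            # Log the optimiser hyperparameters to TensorBoard as an hparams summary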
            for v in tensorboardX.writer.hparams(
                {"optimizer": "Adam", "lr": settings.learning_rate}, {}
            ):
                writer.file_writer.add_summary(v)
            progress = Progress(
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                MofNCompleteColumn(),
                TimeRemainingColumn(),
            )
            with progress:
                for i in progress.track(
                    range(settings.n_epochs),
                    description="Optimising MM parameters",
                ):
                    # Memory optimisation: Use manual gradient computation which
                    # computes gradients per-molecule and detaches immediately
                    losses_train, gradient = compute_overall_loss_and_grad(
                        datasets,
                        trainable,
                        trainable_parameters,
                        initial_parameters,
                        topologies,
                        settings.regularisation_target,
                        str(device),
                    )

                    logger.debug(
                        f"Epoch {i}: Training Weighted Loss: "
                        f"Energy={losses_train.energy.item():.4f} "
                        f"Forces={losses_train.forces.item():.4f} "
                        f"Reg={losses_train.regularisation.item():.4f}"
                    )
                    if i % 10 == 0:
                        losses_test, _ = compute_overall_loss_and_grad(
                            datasets_test,
                            trainable,
                            trainable_parameters,
                            initial_parameters,
                            topologies,
                            settings.regularisation_target,
                            str(device),
                            compute_grad=False,
                        )

                        write_metrics(
                            i,
                            losses_train,
                            losses_test,
                            writer,
                            metrics_file,
                        )

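                    # Attach the manually computed gradient, take an optimiser step,
                    # then clamp the parameters back into their allowed ranges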
                    trainable_parameters.grad = gradient
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    trainable.clamp(trainable_parameters)

                    if i % settings.learning_rate_decay_step == 0:
                        scheduler.step()

        # Required to avoid filling up the GPU memory between iterations
        # TODO: Find a better way to do this.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # some book-keeping and outputting
        losses_train, _ = compute_overall_loss_and_grad(
            datasets,
            trainable,
            trainable_parameters,
            initial_parameters,
            topologies,
            settings.regularisation_target,
            str(device),
            compute_grad=False,
        )
        losses_test, _ = compute_overall_loss_and_grad(
            datasets_test,
            trainable,
            trainable_parameters,
            initial_parameters,
            topologies,
            settings.regularisation_target,
            str(device),
            compute_grad=False,
        )

        write_metrics(
            settings.n_epochs, losses_train, losses_test, writer, metrics_file
        )

        return trainable_parameters, trainable
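
The loop above computes gradients itself (via compute_overall_loss_and_grad) and attaches them to trainable_parameters before stepping the optimiser, rather than calling loss.backward() on one large graph. A minimal, self-contained sketch of that manual-gradient Adam pattern using torch's public API is shown below; the toy loss, clamp bounds and numbers are placeholders, not presto's objective.

import torch

# Toy stand-ins: fit two parameters to fixed targets with manually attached gradients.
params = torch.tensor([0.0, 0.0], requires_grad=True)
targets = torch.tensor([1.5, -0.5])

optimizer = torch.optim.Adam([params], lr=0.05, amsgrad=True)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

for epoch in range(100):
    # Compute the loss and its gradient explicitly, mirroring how train_adam
    # obtains a gradient from compute_overall_loss_and_grad.
    loss = torch.sum((params - targets) ** 2)
    (gradient,) = torch.autograd.grad(loss, params)

    # Attach the gradient, step, and reset, as in the loop above.
    params.grad = gradient
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Stand-in for trainable.clamp: keep parameters inside a valid range.
    with torch.no_grad():
        params.clamp_(-2.0, 2.0)

    # Decay the learning rate on a fixed epoch schedule.
    if epoch % 10 == 0:
        scheduler.step()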