Skip to content

compile_aimnet2_ens_models #

Script to compile AIMNet2 ensemble models for use in presto.

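Classes:

- EnsembledModel – Create an ensemble of AIMNet2 models.

Functions:

- _download_model – Download an AIMNet2 model directly from storage.
- compile_aimnet2_ens_model – Compile an AIMNet2 ensemble model.
- main – Main function to compile and save AIMNet2 ensemble models.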

EnsembledModel #

EnsembledModel(
    models: List,
    x=["coord", "numbers", "charge"],
    out=["energy", "forces", "charges"],
    detach=True,
)

Bases: Module

Create an ensemble of AIMNet2 models.

Source code in presto/models/compile_aimnet2_ens_models.py
def __init__(
    self,
    models: List,
    x=["coord", "numbers", "charge"],
    out=["energy", "forces", "charges"],
    detach=True,
):
    """Wrap a list of models as a single ensemble module.

    Args:
        models: Ensemble members; registered as submodules through
            ``nn.ModuleList``.
        x: Names stored on the instance as ``self.x``. Presumably the
            input keys read by the forward pass (not shown here) --
            confirm against ``forward``.
        out: Names stored on the instance as ``self.out``. Presumably the
            output keys produced by the forward pass -- confirm against
            ``forward``.
        detach: Flag stored as ``self.detach``; its effect is defined by
            the forward pass, which is not shown here.
    """
    super().__init__()
    self.models = nn.ModuleList(models)
    # The default values of ``x`` and ``out`` are list objects created once
    # at function definition time. Store copies so that mutating one
    # instance's lists cannot leak into other instances or into the
    # defaults used by later constructor calls.
    self.x = list(x)
    self.out = list(out)
    self.detach = detach

_download_model #

_download_model(
    method: str,
    version: int = 0,
    device: TorchDevice | None = None,
) -> ScriptModule

Download an AIMNet2 model directly from storage.

Source code in presto/models/compile_aimnet2_ens_models.py
def _download_model(
    method: str, version: int = 0, device: TorchDevice | None = None
) -> torch.jit.ScriptModule:
    """Download an AIMNet2 model directly from storage.

    The file ``{method}_{version}.jpt`` is fetched from ``_MODEL_URL`` into a
    temporary directory, loaded with ``torch.jit.load`` and the temporary
    directory is removed again before returning.

    Args:
        method: Model name; first component of the remote file name.
        version: Integer suffix of the remote file name.
        device: Passed to ``torch.jit.load`` as ``map_location``.

    Returns:
        The loaded TorchScript module.
    """
    import os

    url = f"{_MODEL_URL}{method}_{version}.jpt"

    # Download into a temporary *directory* rather than an open
    # NamedTemporaryFile: re-opening a still-open named temporary file by
    # name is platform-dependent (it fails on Windows), and both
    # urlretrieve and torch.jit.load open the path themselves.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, "model.jpt")
        urllib.request.urlretrieve(url, filename=local_path)
        model: torch.jit.ScriptModule = torch.jit.load(  # type: ignore[no-untyped-call]
            local_path, map_location=device
        )
    return model

compile_aimnet2_ens_model #

compile_aimnet2_ens_model(
    model_name: AvailableModels,
    n_members: int = 4,
    device: TorchDevice = "cpu",
) -> ScriptModule

Compile an AIMNet2 ensemble model.

Args:

- model_name: Name of the AIMNet2 model to compile.
- n_members: Number of ensemble members to include.
- device: Torch device to load models onto.

Returns: Compiled AIMNet2 ensemble model.

Source code in presto/models/compile_aimnet2_ens_models.py
def compile_aimnet2_ens_model(
    model_name: AvailableModels,
    n_members: int = 4,
    device: TorchDevice = "cpu",
) -> torch.jit.ScriptModule:
    """Compile an AIMNet2 ensemble model.

    Downloads versions ``0 .. n_members - 1`` of the named model, wraps them
    in an ``EnsembledModel`` (constructed with ``detach=False``) and scripts
    the result with ``torch.jit.script``.

    Args:
        model_name: Name of the AIMNet2 model to compile.
        n_members: Number of ensemble members to include; must be >= 1.
        device: Torch device to load models onto.

    Returns:
        Compiled AIMNet2 ensemble model.

    Raises:
        ValueError: If ``model_name`` is not one of ``AvailableModels`` or
            ``n_members`` is smaller than 1.
    """
    if model_name not in get_args(AvailableModels):
        raise ValueError(
            f"Invalid model name: {model_name}. Available models are: {get_args(AvailableModels)}"
        )
    # range(n_members) is empty for n_members < 1, which would silently build
    # an ensemble with no members; reject that before any download starts.
    if n_members < 1:
        raise ValueError(f"n_members must be at least 1, got {n_members}.")

    # Member i of the ensemble is the stored model with version suffix i.
    models = [
        _download_model(model_name, version=member_idx, device=device)
        for member_idx in range(n_members)
    ]

    ensemble_model = EnsembledModel(models=models, detach=False)
    scripted_model = torch.jit.script(ensemble_model)  # type: ignore[no-untyped-call]

    return scripted_model

main #

main()

Main function to compile and save AIMNet2 ensemble models.

Source code in presto/models/compile_aimnet2_ens_models.py
def main():
    """Compile every available AIMNet2 ensemble and write each one to disk.

    For each name in ``AvailableModels`` a four-member ensemble is compiled
    on the CPU and saved as ``{model_name}_ens.jpt``, a path relative to the
    current working directory.
    """
    available_models = get_args(AvailableModels)
    for model_name in available_models:
        logger.info(f"Compiling ensemble model for {model_name}...")
        # Output file name is derived from the model name alone.
        output_file = f"{model_name}_ens.jpt"
        compiled = compile_aimnet2_ens_model(model_name, n_members=4, device="cpu")
        compiled.save(output_file)
        logger.info(f"Saved ensemble model to {output_file}")