imranq2 (Contributor) commented Sep 2, 2024

This PR enables ESM to run on Apple silicon Macs (M1, M2, M3) by using PyTorch's Metal Performance Shaders (MPS) backend for GPU acceleration.

PyTorch already supports Mac (MPS): https://pytorch.org/docs/stable/notes/mps.html

Note that the MPS backend does not support some embedding operations (aten::_embedding_bag, per the warning shown further down), so the following environment variable must be set to let PyTorch fall back to the CPU for those operations:

export PYTORCH_ENABLE_MPS_FALLBACK=1
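
If exporting the variable in the shell is inconvenient, it can also be set from Python. This is a minimal sketch, not part of the PR: the helper name enable_mps_fallback is hypothetical, and it assumes PyTorch reads the variable when torch is first loaded, so the call has to come before import torch.

```python
import os


def enable_mps_fallback(env=None):
    """Set PYTORCH_ENABLE_MPS_FALLBACK=1 unless a value is already present.

    `env` defaults to os.environ; a plain dict can be passed for testing.
    The helper name is hypothetical and not part of esm or PyTorch.
    """
    env = os.environ if env is None else env
    # setdefault keeps an explicit user choice (for example "0") untouched.
    env.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
    return env["PYTORCH_ENABLE_MPS_FALLBACK"]


# Assumption: the variable is read when torch is loaded, so call this first.
enable_mps_fallback()
# import torch  # would go here, after the variable is set
```

Using setdefault rather than a plain assignment is a design choice so that a value exported in the shell still wins.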

The repository has no existing tests, so there is no pattern to follow for adding a unit test. Instead, I tested on my MacBook by running the following script, which works with the changes in this PR.

import os
from huggingface_hub import login
from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
import torch

# Check if the Hugging Face API token is available in the environment
token = os.getenv("HF_API_TOKEN")

if token:
    # Log in with the existing token so the model download is authenticated
    login(token=token)
    print("Using existing Hugging Face token.")
else:
    # Prompt the user to log in if no token is found
    login()

# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

# Set the device to MPS (for Mac M1/M2) or CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# device = "cpu"
print(f"device: {device}")

# Load the ESM3 model (esm3_sm_open_v1) and move it to the selected device
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)

# Check if the model is on MPS
model_device = next(model.parameters()).device
print(f"Model is running on device: {model_device}")

# Example protein sequence
# sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAN___"
# sequence = "___________________________________________________DQATSLRILNNGHAFGSLTTPP___________________________________________________________"
sequence = "___DQA___"

# Create an ESMProtein object with the sequence
protein = ESMProtein(sequence=sequence)

# Generate the sequence prediction (optional, if needed)
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8, temperature=0.7))
# Print out the predicted sequence
predicted_sequence = protein.sequence
print("Predicted Sequence:")
print(predicted_sequence)

# Generate the structure prediction
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))

# Save the predicted structure to a PDB file
protein.to_pdb("./predicted_structure.pdb")

# Optionally, perform a round-trip design by inverse folding the sequence and recomputing the structure
protein.sequence = None
protein = model.generate(protein, GenerationConfig(track="sequence", num_steps=8))
protein.coordinates = None
protein = model.generate(protein, GenerationConfig(track="structure", num_steps=8))
protein.to_pdb("./round_tripped_structure.pdb")

print("Structure prediction complete. PDB files saved.")

This prints:
Model is running on device: mps:0
confirming that MPS is being used.
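
The script only inspects the first parameter. A slightly stronger check counts every parameter per device. This is a rough sketch, not part of the PR: parameter_devices and assert_on_device are hypothetical helper names, and the code only assumes that the model exposes parameters() and that each parameter has a device attribute, as torch.nn.Module does. Buffers are not covered.

```python
from collections import Counter


def parameter_devices(model):
    """Count parameters per device for any object exposing .parameters()."""
    return Counter(str(p.device) for p in model.parameters())


def assert_on_device(model, expected_type):
    """Raise if any parameter is on a device type other than expected_type.

    expected_type is the part before the colon, for example "mps" or "cpu".
    """
    counts = parameter_devices(model)
    stray = {d: n for d, n in counts.items() if d.split(":")[0] != expected_type}
    if stray:
        raise RuntimeError(f"parameters not on {expected_type}: {stray}")
    return counts
```

With the script above, assert_on_device(model, "mps") would be expected to return a single "mps:0" entry if every parameter was moved.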

It also prints this warning, showing that the embedding operation fell back to the CPU:
UserWarning: The operator 'aten::_embedding_bag' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.
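
To see which operators fall back without reading the console output, the warning can be captured with the standard warnings module. This is a sketch under assumptions: collect_mps_fallback_ops is a hypothetical helper, the pattern is copied from the warning text quoted above, and PyTorch may emit this warning only once per process, so the returned list may be incomplete.

```python
import re
import warnings

# Pattern taken from the warning text quoted above.
_FALLBACK_RE = re.compile(
    r"The operator '([^']+)' is not currently supported on the MPS backend"
)


def collect_mps_fallback_ops(fn, *args, **kwargs):
    """Run fn and return (result, sorted operator names seen in fallback warnings)."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning instead of letting the default filters drop repeats.
        warnings.simplefilter("always")
        result = fn(*args, **kwargs)
    ops = set()
    for w in caught:
        match = _FALLBACK_RE.search(str(w.message))
        if match:
            ops.add(match.group(1))
    return result, sorted(ops)
```

For example, wrapping the first model.generate call from the script in collect_mps_fallback_ops would return the generated protein together with any operator names reported during that call.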

imranq2 changed the title from "fix for Mac silicon" to "Enable running ESM on Mac silicon using MPS" on Sep 2, 2024
gjoliver (Contributor) left a comment

ha, this is very cool! thanks a lot for the PR.
can you revert the .gitignore changes? as a python package, we don't need to include dev env related things.

imranq2 (Contributor, Author) commented Oct 9, 2024

> ha, this is very cool! thanks a lot for the PR. can you revert the .gitignore changes? as a python package, we don't need to include dev env related things.

Removed .gitignore file

ebetica (Contributor) commented Oct 11, 2024

We have to set up a CLA before merging this PR 🙈 Sorry for not merging it for so long, I'll get to it hopefully soon.

imranq2 (Contributor, Author) commented Jan 28, 2025

> We have to set up a CLA before merging this PR 🙈 Sorry for not merging it for so long, I'll get to it hopefully soon.

Any updates?

Signed-off-by: Zeming Lin <ebetica0@gmail.com>
ebetica merged commit c896919 into evolutionaryscale:main on Sep 19, 2025
1 of 3 checks passed