Serve Gemma open models using TPUs on Vertex AI Prediction with Saxml

This guide shows you how to serve Gemma open models, a family of large language models (LLMs), using Tensor Processing Units (TPUs) on Vertex AI Prediction with Saxml. In this guide, you download the 2B and 7B parameter instruction-tuned Gemma models to Cloud Storage and deploy them to Vertex AI Prediction, which serves the models with Saxml on TPUs.

Background

By serving Gemma using TPUs on Vertex AI Prediction with Saxml, you can take advantage of a managed AI solution that takes care of low-level infrastructure and offers a cost-effective way to serve LLMs. This section describes the key technologies used in this tutorial.

Gemma

Gemma is a set of openly available, lightweight generative artificial intelligence (AI) models released under an open license. You can run these AI models in your applications, on your hardware and mobile devices, or on hosted services. You can use the Gemma models for text generation, and you can also tune them for specialized tasks.

To learn more, see the Gemma documentation.

Saxml

Saxml is an experimental system that serves Paxml, JAX, and PyTorch models for inference. This tutorial covers serving Gemma on TPUs, which are more cost efficient for Saxml; the setup for GPUs is similar. Saxml provides scripts to build container images for Vertex AI Prediction; this tutorial uses a prebuilt Saxml container image.

TPUs

TPUs are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate data processing frameworks such as TensorFlow, PyTorch, and JAX.

This tutorial serves the Gemma 2B and Gemma 7B instruction-tuned models. Vertex AI Prediction hosts these models on the following single-host TPU v5e configurations:

  • Gemma 2B: Hosted on a TPU v5e with 1x1 topology, which represents one TPU chip. The machine type for the node is ct5lp-hightpu-1t.
  • Gemma 7B: Hosted on a TPU v5e with 2x2 topology, which represents four TPU chips. The machine type for the node is ct5lp-hightpu-4t.

Before you begin

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Make sure that billing is enabled for your Google Cloud project.

  4. Enable the Vertex AI and Artifact Registry APIs. You can enable them from the console, or with gcloud as shown after this list.

    Enable the APIs

  5. In the Google Cloud console, activate Cloud Shell.

    Activate Cloud Shell

    At the bottom of the Google Cloud console, a Cloud Shell session starts and displays a command-line prompt. Cloud Shell is a shell environment with the Google Cloud CLI already installed and with values already set for your current project. It can take a few seconds for the session to initialize.
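
If you prefer the command line, you can also enable the required APIs from Cloud Shell instead of the console; aiplatform.googleapis.com is the service name for the Vertex AI API:

gcloud services enable aiplatform.googleapis.com artifactregistry.googleapis.com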

This tutorial assumes that you are using Cloud Shell to interact with Google Cloud. If you want to use a different shell instead of Cloud Shell, then perform the following additional configuration:

  1. Install the Google Cloud CLI.
  2. To initialize the gcloud CLI, run the following command:

    gcloud init
  3. Follow the Artifact Registry documentation to Install Docker.
  4. Ensure that you have sufficient quotas for 5 TPU v5e chips for Vertex AI Prediction.
  5. Create a Kaggle account, if you don't already have one.

Get access to the model

To get access to the Gemma models for deployment to Vertex AI Prediction, you must sign in to the Kaggle platform, sign the license consent agreement, and get a Kaggle API token. In this tutorial, you use the Kaggle API token in Cloud Shell to download the model checkpoints.

You must sign the consent agreement to use Gemma. Follow these instructions:

  1. Access the model consent page on Kaggle.com.
  2. Sign in to Kaggle if you haven't done so already.
  3. Click Request Access.
  4. In the Choose Account for Consent section, select Verify via Kaggle Account to use your Kaggle account for consent.
  5. Accept the model Terms and Conditions.

Generate an access token

To access the model through Kaggle, you need a Kaggle API token.

Follow these steps to generate a new token if you don't have one already:

  1. In your browser, go to Kaggle settings.
  2. Under the API section, click Create New Token.

    A file named kaggle.json is downloaded.

Upload the access token to Cloud Shell

In Cloud Shell, upload the Kaggle API token so that you can use it with your Google Cloud project:

  1. In Cloud Shell, click More > Upload.
  2. Select File and click Choose Files.
  3. Open the kaggle.json file.
  4. Click Upload.
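
The kaggle command-line tool, which you install later in this tutorial, looks for the token at ~/.kaggle/kaggle.json by default. Assuming the file was uploaded to your home directory (the Cloud Shell upload default), you can move it into place and restrict its permissions:

mkdir -p ~/.kaggle
mv kaggle.json ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json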

Create the Cloud Storage bucket

Create a Cloud Storage bucket to store the model checkpoints.

In Cloud Shell, run the following:

gcloud storage buckets create gs://CHECKPOINTS_BUCKET_NAME

Replace CHECKPOINTS_BUCKET_NAME with the name of the Cloud Storage bucket that stores the model checkpoints.
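
Optionally, you can co-locate the bucket with the region where you deploy the model (for example, us-west1 for TPUs on Vertex AI Prediction) by passing the --location flag; for example:

gcloud storage buckets create gs://CHECKPOINTS_BUCKET_NAME --location=us-west1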

Copy the model to the Cloud Storage bucket

In Cloud Shell, run the following:

pip install kaggle --break-system-packages

# For Gemma 2B
mkdir -p /data/gemma_2b-it
kaggle models instances versions download google/gemma/pax/2b-it/1 --untar -p /data/gemma_2b-it
gsutil -m cp -R /data/gemma_2b-it/* gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/

# For Gemma 7B
mkdir -p /data/gemma_7b-it
kaggle models instances versions download google/gemma/pax/7b-it/1 --untar -p /data/gemma_7b-it
gsutil -m cp -R /data/gemma_7b-it/* gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/
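
To confirm that the checkpoints copied successfully, you can list the bucket contents:

gcloud storage ls gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/
gcloud storage ls gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/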

Create an Artifact Registry repository

Create an Artifact Registry repository to store the container image that you push in the next section.

Enable the Artifact Registry API service for your project.

gcloud services enable artifactregistry.googleapis.com

Run the following command in your shell to create the Artifact Registry repository:

gcloud artifacts repositories create saxml \
 --repository-format=docker \
 --location=LOCATION \
 --description="Saxml Docker repository"

Replace LOCATION with the region where Artifact Registry stores your container image. Later, you must create a Vertex AI model resource on a regional endpoint that matches this region, so choose a region where Vertex AI has a regional endpoint, such as us-west1 for TPUs.
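
To verify that the repository was created, you can describe it:

gcloud artifacts repositories describe saxml \
 --location=LOCATION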

Push the container image to Artifact Registry

A prebuilt Saxml container image is available at us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest. In this section, you copy it to your Artifact Registry repository: configure Docker to authenticate with Artifact Registry, pull and tag the prebuilt image, and then push it to your repository.

  1. To give your local Docker installation permission to push to Artifact Registry in your chosen region, run the following command in your shell:

    gcloud auth configure-docker LOCATION-docker.pkg.dev
    
    • Replace LOCATION with the region where you created your repository.
  2. To pull the prebuilt Saxml container image and tag it for your Artifact Registry repository, run the following commands in your shell:

    docker pull us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest
    docker tag us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest
    
  3. To push the container image that you just tagged to your Artifact Registry repository, run the following command in your shell:

    docker push LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest
    

    Replace the following, as you did in the previous section:

    • LOCATION: the region of your Artifact Registry repository.
    • PROJECT_ID: the ID of your Google Cloud project
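  4. To confirm that the image is now in your repository, you can list the images it contains:

    gcloud artifacts docker images list LOCATION-docker.pkg.dev/PROJECT_ID/saxml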

Deploying the model

Upload a model

To upload a Model resource that uses your Saxml container, run the following gcloud ai models upload command:

Gemma 2B-it

gcloud ai models upload \
  --region=LOCATION \
  --display-name=DEPLOYED_MODEL_NAME \
  --container-image-uri=LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest \
  --artifact-uri='gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/' \
  --container-args='--model_path=saxml.server.pax.lm.params.gemma.Gemma2BFP16' \
  --container-args='--platform_chip=tpuv5e' \
  --container-args='--platform_topology=2x2' \
  --container-args='--ckpt_path_suffix=checkpoint_00000000' \
  --container-ports=8502

Gemma 7B-it

gcloud ai models upload \
  --region=LOCATION \
  --display-name=DEPLOYED_MODEL_NAME \
  --container-image-uri=LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest \
  --artifact-uri='gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/' \
  --container-args='--model_path=saxml.server.pax.lm.params.gemma.Gemma7BFP16' \
  --container-args='--platform_chip=tpuv5e' \
  --container-args='--platform_topology=2x2' \
  --container-args='--ckpt_path_suffix=checkpoint_00000000' \
  --container-ports=8502

Replace the following:

  • PROJECT_ID: The ID of your Google Cloud project.
  • LOCATION: The region where you are using Vertex AI. Note that TPUs are available only in us-west1.
  • DEPLOYED_MODEL_NAME: A name for the DeployedModel. You can also use the display name of the Model for the DeployedModel.
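
To confirm that the Model resource was created, you can list your models and filter by display name; the same pattern is used later in this guide:

gcloud ai models list \
  --region=LOCATION \
  --filter=display_name=DEPLOYED_MODEL_NAME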

Create an endpoint

You must deploy the model to an endpoint before the model can be used to serve online predictions. If you are deploying a model to an existing endpoint, you can skip this step. The following example uses the gcloud ai endpoints create command:

gcloud ai endpoints create \
  --region=LOCATION \
  --display-name=ENDPOINT_NAME

Replace the following:

  • LOCATION: The region where you are using Vertex AI.
  • ENDPOINT_NAME: The display name for the endpoint.

The Google Cloud CLI tool might take a few seconds to create the endpoint.

Deploy the model to the endpoint

After the endpoint is ready, deploy the model to the endpoint.

ENDPOINT_ID=$(gcloud ai endpoints list \
   --region=LOCATION \
   --filter=display_name=ENDPOINT_NAME \
   --format="value(name)")

MODEL_ID=$(gcloud ai models list \
   --region=LOCATION \
   --filter=display_name=DEPLOYED_MODEL_NAME \
   --format="value(name)")

gcloud ai endpoints deploy-model $ENDPOINT_ID \
  --region=LOCATION \
  --model=$MODEL_ID \
  --display-name=DEPLOYED_MODEL_NAME \
  --machine-type=ct5lp-hightpu-4t \
  --traffic-split=0=100

Replace the following:

  • LOCATION: The region where you are using Vertex AI.
  • ENDPOINT_NAME: The display name for the endpoint.
  • DEPLOYED_MODEL_NAME: A name for the DeployedModel. You can use the display name of the Model for the DeployedModel as well.

Gemma 2B can be deployed on a smaller ct5lp-hightpu-1t machine. In that case, specify --platform_topology=1x1 when you upload the model, as sketched below.
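
A sketch of the flags that change relative to the commands above; all other arguments stay the same:

# In the gcloud ai models upload command, request a single-chip topology:
--container-args='--platform_topology=1x1'

# In the gcloud ai endpoints deploy-model command, use the single-chip machine type:
--machine-type=ct5lp-hightpu-1t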

The Google Cloud CLI tool might take a few minutes to deploy the model to the endpoint. When the model is successfully deployed, this command prints the following output:

  Deployed a model to the endpoint xxxxx. Id of the deployed model: xxxxx.
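
If you want to inspect the deployment, you can describe the endpoint; the deployed model and its ID appear in the output (the same command is used again in the cleanup section):

gcloud ai endpoints describe $ENDPOINT_ID \
  --region=LOCATION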

Getting online predictions from the deployed model

To invoke the model through the Vertex AI Prediction endpoint, format the prediction request by using a standard inference request JSON object.

The following example uses the gcloud ai endpoints predict command:

ENDPOINT_ID=$(gcloud ai endpoints list \
   --region=LOCATION \
   --filter=display_name=ENDPOINT_NAME \
   --format="value(name)")

gcloud ai endpoints predict $ENDPOINT_ID \
  --region=LOCATION \
  --http-headers=Content-Type=application/json \
  --json-request instances.json

Replace the following:

  • LOCATION: The region where you are using Vertex AI.
  • ENDPOINT_NAME: The display name for the endpoint.
  • instances.json: A JSON file with the following format: {"instances": [{"text_batch": "<your prompt>"},{...}]}. See the example after this list.
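
For example, you can create a minimal instances.json in your shell as follows; the prompt text is only an illustration:

cat > instances.json <<EOF
{"instances": [{"text_batch": "What is a TPU and what is it used for?"}]}
EOF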

Cleaning up

To avoid incurring further Vertex AI, Artifact Registry, and Cloud Storage charges, delete the Google Cloud resources that you created during this tutorial:

  1. To undeploy the model from the endpoint and delete the endpoint, run the following command in your shell:

    ENDPOINT_ID=$(gcloud ai endpoints list \
       --region=LOCATION \
       --filter=display_name=ENDPOINT_NAME \
       --format="value(name)")
    
    DEPLOYED_MODEL_ID=$(gcloud ai endpoints describe $ENDPOINT_ID \
       --region=LOCATION \
       --format="value(deployedModels.id)")
    
    gcloud ai endpoints undeploy-model $ENDPOINT_ID \
      --region=LOCATION \
      --deployed-model-id=$DEPLOYED_MODEL_ID
    
    gcloud ai endpoints delete $ENDPOINT_ID \
       --region=LOCATION \
       --quiet
    

    Replace LOCATION with the region where you created your model in a previous section.

  2. To delete your model, run the following command in your shell:

    MODEL_ID=$(gcloud ai models list \
       --region=LOCATION \
       --filter=display_name=DEPLOYED_MODEL_NAME \
       --format="value(name)")
    
    gcloud ai models delete $MODEL_ID \
       --region=LOCATION \
       --quiet
    

    Replace LOCATION with the region where you created your model in a previous section.

  3. To delete your Artifact Registry repository and the container image in it, run the following command in your shell:

    gcloud artifacts repositories delete saxml \
      --location=LOCATION \
      --quiet
    

    Replace LOCATION with the region where you created your Artifact Registry repository in a previous section.
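
  4. If you no longer need the model checkpoints, delete the Cloud Storage bucket that you created earlier by running the following command in your shell. This removes the bucket and all of the objects in it:

    gcloud storage rm --recursive gs://CHECKPOINTS_BUCKET_NAME

    Replace CHECKPOINTS_BUCKET_NAME with the name of the bucket that stores the model checkpoints.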

Limitations

  • On Vertex AI Prediction, Cloud TPUs are supported only in us-west1. For more information, see locations.

What's next

  • Learn how to deploy other Saxml models such as Llama2 and GPT-J.