Deploying MLFlow to GCP with Terraform and Bash

At some point in any SysAdmin’s life, they will be asked to manage some ML infrastructure. If they’re lucky, it’ll be something simple.

To help our nascent data science team start tracking their experiments better, we needed MLFlow. Deploying MLFlow itself is fairly simple. You need a database instance and the actual service. There are a thousand different ways to launch a web service and a backing DB, but we needed a few specifics:

  1. This service should be relatively cheap
  2. We only want our users to access this service, so we need this to be behind IAP
  3. Ideally, DB maintenance isn’t our problem.

#1 and #3 aren’t too hard in GCP: We’ll use Cloud Run with minScale at 0 to ensure that the service tears down running instances when they aren’t being accessed. Cool. Money saved. Although we do end up with occasional slow “cold” startups (30+ seconds). Since this is an internal tool that only a few people need access to, this is acceptable. And for the DB, we’ll just use Cloud SQL. Why wouldn’t we?

#2 means we need a load balancer with our OAuth2 settings configured properly.

IaC for MLFlow

Anyway…the point of this post is to have a nice example of doing all this with terraform, so…

# Look up the pre-existing "default" VPC so the Cloud SQL instance can
# attach to it for private-IP connectivity.
data "google_compute_network" "default" {
  name = "default"
}

variable "region" {
  description = "GCP region for the Cloud SQL instance, Cloud Run service, and serverless NEG."
  default     = "us-central1"
  type        = string
}

variable "mlflow-db-password" {
  description = "Password for the 'mlflow' Postgres user; interpolated into the backend-store connection string."
  type      = string
  sensitive = true
}

variable "oauth2_client_secret" {
  description = "OAuth2 client secret, used both by IAP on the backend service and in the oauth2 config secret."
  type      = string
  sensitive = true
}

variable "oauth2_client_id" {
  description = "OAuth2 client ID matching oauth2_client_secret."
  type      = string
  sensitive = true
}

# Secret Manager slot for the OAuth2 proxy configuration.
# NOTE(review): nothing in this file mounts this secret — presumably it is
# consumed by an oauth2-proxy instance defined elsewhere; confirm.
resource "google_secret_manager_secret" "oauth-config" {
  secret_id = "mlflow-oauth-config"

  replication {
    automatic = true
  }
}

# The heredoc below is stored verbatim as the secret payload, so no
# comments may be placed inside it.
# NOTE(review): <COOKIE SECRET IN BASE64> is a placeholder — replace with a
# real generated value before use.
resource "google_secret_manager_secret_version" "oauth-config-version" {
  secret = google_secret_manager_secret.oauth-config.id

  secret_data = <<EOT
email_domains = [
    "example.com",
]
provider = "google"
client_id = "${var.oauth2_client_id}"
client_secret = "${var.oauth2_client_secret}"
skip_jwt_bearer_tokens = true
extra_jwt_issuers = "https://accounts.google.com=<>"
cookie_secret = "<COOKIE SECRET IN BASE64>"
EOT
}

# Small, single-zone Postgres 14 instance for the MLFlow backend store.
resource "google_sql_database_instance" "mlflow" {
  name                = "mlflow"
  region              = var.region
  database_version    = "POSTGRES_14"
  # Protect the experiment database from an accidental `terraform destroy`.
  deletion_protection = true

  settings {
    # Minimal footprint: 10 GB HDD, 1 vCPU / 4 GB RAM, no HA.
    disk_size         = 10
    disk_type         = "PD_HDD"
    tier              = "db-custom-1-4096"
    availability_type = "ZONAL"

    backup_configuration {
      enabled                        = true
      location                       = "us"
      point_in_time_recovery_enabled = false
      backup_retention_settings {
        retained_backups = 2
      }
    }

    insights_config {
      query_insights_enabled  = false
      query_string_length     = 1024
      record_application_tags = false
      record_client_address   = false
    }
    # Log any statement slower than 1000 ms for debugging.
    database_flags {
      name  = "log_min_duration_statement"
      value = "1000"
    }
    ip_configuration {
      // this only disables public connections
      //ipv4_enabled    = false
      private_network = data.google_compute_network.default.id
    }
    # day 7 / hour 8: Sundays at 08:00 UTC.
    maintenance_window {
      day  = 7
      hour = 8
    }
  }
}

# Logical "mlflow" database inside the instance. ABANDON means removing
# the terraform resource leaves the actual data untouched.
resource "google_sql_database" "database" {
  name            = "mlflow"
  instance        = google_sql_database_instance.mlflow.name
  deletion_policy = "ABANDON"
}

locals {
  # Postgres URI connecting over the Cloud SQL unix socket that Cloud Run
  # mounts under /cloudsql/<connection name>.
  # BUG FIX: the original line ended with ` : ""` — a dangling half of a
  # conditional expression, which is invalid HCL.
  mlflow-db-connection = "postgresql://mlflow:${var.mlflow-db-password}@/mlflow?host=/cloudsql/${google_sql_database_instance.mlflow.connection_name}"
}

# Secret holding the full DB connection string; surfaced to Cloud Run as
# the BACKEND_STORE_URI environment variable.
resource "google_secret_manager_secret" "mlflow-db-connection" {
  secret_id = "mlflow-db-connection"

  replication {
    automatic = true
  }
}

resource "google_secret_manager_secret_version" "mlflow-db-connection-version" {
  secret = google_secret_manager_secret.mlflow-db-connection.id

  secret_data = local.mlflow-db-connection
}

// backend service is where lb requests are sent
resource "google_compute_backend_service" "mlflow" {
  name  = "mlflow"
  backend {
    # Serverless NEG pointing at the Cloud Run service.
    group = google_compute_region_network_endpoint_group.mlflow.id
  }
  # Enabling IAP here is what puts the whole service behind Google login
  # (requirement #2 from the post).
  iap {
    oauth2_client_id     = var.oauth2_client_id
    oauth2_client_secret = var.oauth2_client_secret
  }
}

# Route mlflow.example.com to the backend service; any other host that
# reaches the LB (e.g. the bare IP) is redirected to the canonical name.
resource "google_compute_url_map" "mlflow" {
  name  = "mlflow"

  default_url_redirect {
    host_redirect          = "mlflow.example.com"
    redirect_response_code = "MOVED_PERMANENTLY_DEFAULT"
    strip_query            = true
  }

  host_rule {
    hosts        = ["mlflow.example.com"]
    path_matcher = "mlflow"
  }

  path_matcher {
    name            = "mlflow"
    default_service = google_compute_backend_service.mlflow.id
  }
}

// manage certificate
# Random suffix for the managed certificate name. Changing the domain list
# forces a new name, which (with create_before_destroy on the cert) allows
# zero-downtime certificate replacement.
resource "random_id" "mlflow-cert" {
  byte_length = 4
  prefix      = "mlflow-"

  keepers = {
    # BUG FIX: the original was missing the closing parenthesis on join().
    domains = join(",", ["mlflow.example.com"])
  }
}

# Google-managed certificate. The randomized name plus create_before_destroy
# lets terraform stand up a replacement cert before tearing down the old one.
resource "google_compute_managed_ssl_certificate" "mlflow" {
  name  = random_id.mlflow-cert.hex

  managed {
    domains = ["mlflow.example.com"]
  }
  lifecycle {
    create_before_destroy = true
  }
}

// configure url mappings and attach certificate
# Terminates TLS for the load balancer and forwards to the URL map above.
resource "google_compute_target_https_proxy" "mlflow" {
  name    = "mlflow"
  url_map = google_compute_url_map.mlflow.id
  ssl_certificates = [
    google_compute_managed_ssl_certificate.mlflow.self_link,
  ]
}

# Serverless NEG: the glue that lets a regional Cloud Run service sit
# behind a global external HTTPS load balancer.
resource "google_compute_region_network_endpoint_group" "mlflow" {
  name                  = "mlflow"
  network_endpoint_type = "SERVERLESS"
  region                = var.region
  cloud_run {
    service = google_cloud_run_service.mlflow.name
  }
}

# Static global IP for the load balancer; point DNS for mlflow.example.com
# at this address.
resource "google_compute_global_address" "mlflow_lb" {
  name  = "mlflow-lb-address"
}

# Artifact storage bucket, wired into Cloud Run as DEFAULT_ARTIFACT_ROOT.
resource "google_storage_bucket" "mlflow-bucket" {
  storage_class = "STANDARD"
  name          = "example-mlflow"
  # GCS locations are upper-cased region names (e.g. US-CENTRAL1).
  location      = upper(var.region)
  # NOTE(review): force_destroy deletes every stored artifact on
  # `terraform destroy` — confirm that is acceptable for experiment data.
  force_destroy = true

  uniform_bucket_level_access = true
}

# The MLFlow tracking server itself, running on Cloud Run.
resource "google_cloud_run_service" "mlflow" {
  name     = "mlflow"
  location = var.region

  template {
    spec {
      # Generous request timeout so long-running artifact transfers survive.
      timeout_seconds = 3600
      containers {
        image = "<mlflow container>"
        resources {
          limits = {
            cpu    = "1000m"
            memory = "1Gi"
          }
        }
        # Read by start.sh when launching `mlflow server`.
        env {
          name  = "DEFAULT_ARTIFACT_ROOT"
          value = "gs://${google_storage_bucket.mlflow-bucket.name}"
        }
        # DB connection string (contains the password) comes from Secret
        # Manager rather than being inlined in the service definition.
        env {
          name = "BACKEND_STORE_URI"
          value_from {
            secret_key_ref {
              key  = "latest"
              name = google_secret_manager_secret.mlflow-db-connection.secret_id
            }
          }
        }
        ports {
          container_port = 8080
        }
      }
    }
    metadata {
      annotations = {
        # Limit scale up to prevent any cost blow outs!
        "autoscaling.knative.dev/maxScale"         = "80"
        # minScale 0 scales to zero when idle (cheap, at the price of the
        # 30+ second cold starts mentioned in the post).
        "autoscaling.knative.dev/minScale"         = "0"
        # Mounts the Cloud SQL unix socket under /cloudsql/<connection name>.
        "run.googleapis.com/cloudsql-instances"    = google_sql_database_instance.mlflow.connection_name
        "run.googleapis.com/execution-environment" = "gen2"
      }
    }
  }
  autogenerate_revision_name = true
  metadata {
    annotations = {
      # For valid annotation values and descriptions, see
      # https://cloud.google.com/sdk/gcloud/reference/run/deploy#--ingress
      # Only load-balancer (and internal) traffic may reach the service;
      # IAP on the LB then authenticates actual users.
      "run.googleapis.com/ingress"     = "internal-and-cloud-load-balancing"
      "client.knative.dev/user-image"  = "<mlflow container>"
    }
  }
  lifecycle {
    ignore_changes = [
      # This annotation appears to be auto-generated somewhere in GCP, so avoid killing it from the terraform side
      metadata.0.annotations["run.googleapis.com/operation-id"],
    ]
  }
  traffic {
    percent         = 100
    latest_revision = true
  }
}

# LB entry points: port 80 goes to the HTTPS-redirect proxy, port 443 to
# the real HTTPS proxy. Both bind the single reserved global address.
resource "google_compute_global_forwarding_rule" "mlflow_http_forward" {
  name       = "mlflow-http"
  target     = google_compute_target_http_proxy.mlflow_https_redirect.self_link
  ip_address = google_compute_global_address.mlflow_lb.address
  port_range = "80"
}

resource "google_compute_global_forwarding_rule" "mlflow_https" {
  name       = "mlflow"
  target     = google_compute_target_https_proxy.mlflow.self_link
  ip_address = google_compute_global_address.mlflow_lb.address
  port_range = "443"
}

# HTTP -> HTTPS redirect: a dedicated URL map that 301s every request to
# https, served by the plain-HTTP target proxy below.
resource "google_compute_url_map" "mlflow_https_redirect" {
  name  = "mlflow-https-redirect"
  default_url_redirect {
    https_redirect         = true
    redirect_response_code = "MOVED_PERMANENTLY_DEFAULT"
    strip_query            = false
  }
}

resource "google_compute_target_http_proxy" "mlflow_https_redirect" {
  name    = "mlflow-http-redirect"
  url_map = google_compute_url_map.mlflow_https_redirect.self_link
}

# Allow unauthenticated invocation at the Cloud Run layer. Combined with
# ingress = internal-and-cloud-load-balancing above, traffic can only
# arrive via the load balancer, where IAP enforces authentication.
# NOTE(review): consider granting roles/run.invoker to the IAP service
# agent instead of allUsers for defense in depth — confirm against the
# current IAP + Cloud Run guidance before changing.
resource "google_cloud_run_service_iam_binding" "mlflow" {
  location = google_cloud_run_service.mlflow.location
  service  = google_cloud_run_service.mlflow.name
  role     = "roles/run.invoker"
  members = [
    "allUsers"
  ]
}

Building MLFlow

We should also describe how we actually built and managed the MLFlow container. This part is fairly straightforward, but generally, we need to:

  1. configure python dependencies
    • We like poetry at $JOB, so we’ll use that
  2. create a startup script
  3. Create a Dockerfile
  4. Create a simple script to build and push the container consistently
    • We want to potentially use a service account in GCP for auth, so we should handle both default credentials and service account auth.

pyproject.toml

# Poetry project definition for the MLFlow server container.
[tool.poetry]
name = "mlflow"
version = "0.1.0"
description = ""
authors = [""]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
# GCS client (presumably so MLFlow can serve artifacts from the bucket)
google-cloud-storage = "^2.7.0"
# Postgres driver for the Cloud SQL backend store
psycopg2-binary = "^2.9.5"
mlflow = "^2.3.1"
# pinned to work around a dependency conflict; likely removable later
numba = "^0.56.4"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

(We explicitly add numba because of a conflicting package in dependencies; there’s a good chance it could be removed)

Startup script

#!/usr/bin/env bash
# Container entrypoint: launch the MLFlow tracking server.
# set -e so the container exits non-zero (and gets restarted) if the
# server process dies.
set -e

# BACKEND_STORE_URI and DEFAULT_ARTIFACT_ROOT are injected by the Cloud Run
# service definition (Secret Manager ref and bucket URL respectively).
poetry run mlflow server --host 0.0.0.0 --port 8080 --backend-store-uri "${BACKEND_STORE_URI}" --default-artifact-root "${DEFAULT_ARTIFACT_ROOT}" --serve-artifacts

(set -e, while generally discouraged, is nice here to ensure that we exit with an error from the container if the server crashes)

Dockerfile

# MLFlow server image: slim Python base + poetry-managed dependencies.
FROM python:3.10-slim

ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONIOENCODING='utf-8'
ENV LANG='C.UTF-8'

# Install poetry (and crcmod — presumably for faster GCS checksumming);
# disable virtualenv creation so packages land in the system site-packages,
# keeping the image small.
RUN pip install --no-cache-dir --disable-pip-version-check wheel \
    && pip install --no-cache-dir --disable-pip-version-check \
        crcmod \
        poetry \
    && poetry config virtualenvs.create false

WORKDIR /app
COPY pyproject.toml /app
COPY poetry.lock /app

# Install runtime dependencies only, then strip poetry/apt caches to
# shrink the layer.
RUN poetry install --only=main \
    && rm -r /root/.cache/pypoetry/cache /root/.cache/pypoetry/artifacts/ \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

COPY start.sh /app/start.sh
RUN chmod +x /app/start.sh

ENTRYPOINT ["/app/start.sh"]

Fairly standard Dockerfile. Most of the boilerplate is around ensuring we don’t eat too much space with Python package installs (removing the virtualenv, artifact downloads, etc.)

Build Script

We want a build script so we can consistently build the container locally and in our CI environment. The nice benefit here is that if the CI environment goes down (like Github CI frequently does) we can still deploy fixes to our services.

#!/usr/bin/env bash
#
# Build the MLFlow container; optionally push it to GCR, deploy it to
# Cloud Run, and prune old inactive revisions.
#
# Environment:
#   DEPLOY_KEY_FILE  JSON key *content* for a GCP service account
#                    (optional; default credentials are used when unset)
#   PUSH             "yes" to push + deploy (default: build only)
#   REGION           deploy region (default: us-east1)

set -eo pipefail

DEPLOY_KEY_FILE=${DEPLOY_KEY_FILE:-""}

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
PROJECT="<project name>"
REGION="${REGION:-us-east1}"
IMG_PATH="gcr.io/${PROJECT}/mlflow"
export IMG_PATH
TAG="${IMG_PATH}:$(date +%Y%m%d-%H%M%S)"
export TAG
DOCKER_IMAGE=$(grep FROM "${SCRIPT_DIR}/docker/Dockerfile" | awk '{print $2}')
export DOCKER_IMAGE
PUSH=${PUSH:-no}
SERVICE_NAME=mlflow

# ensure the base container is up to date
docker pull "${DOCKER_IMAGE}"

# actually build and push our container (with the potential for multi-arch builds)
docker buildx build --tag "${TAG}" --load "${SCRIPT_DIR}/docker"

if [[ "${PUSH}" == "yes" ]]; then
    if [[ -n "${DEPLOY_KEY_FILE}" ]]; then
        # FIX: write the key material to a private temp file (mktemp creates
        # it 0600) instead of a fixed, world-readable /tmp path, and clean it
        # up even if activation fails. printf avoids echo's escape handling.
        keyfile=$(mktemp)
        trap 'rm -f "${keyfile}"' EXIT
        printf '%s' "${DEPLOY_KEY_FILE}" > "${keyfile}"
        gcloud auth activate-service-account --key-file "${keyfile}"
        rm -f "${keyfile}"

        docker --config /opt/docker-config/ push "${TAG}"
        # ensure we _also_ tag with "latest"
        docker tag "${TAG}" "${IMG_PATH}:latest"
        docker --config /opt/docker-config/ push "${IMG_PATH}:latest"
    else
        docker push "${TAG}"
        # ensure we _also_ tag with "latest"
        docker tag "${TAG}" "${IMG_PATH}:latest"
        docker push "${IMG_PATH}:latest"
    fi
    gcloud run deploy --project "${PROJECT}" --image "${TAG}" --region "${REGION}" "${SERVICE_NAME}"
    ###
    # Clean old (inactive) revisions, keeping the newest KEEP_COUNT.
    ###
    KEEP_COUNT=5

    # FIX: the original passed both "--format json" and "--format=value(...)";
    # gcloud only honors the last --format, so the json one was dead weight.
    # tac reverses the list so the oldest revisions come first.
    old_revisions=$(gcloud --project "${PROJECT}" run revisions list --region "${REGION}" --service "${SERVICE_NAME}" --filter="status.conditions.type:Active AND status.conditions.status:'False'" --format='value(metadata.name)' | tac)
    if [[ -z "${old_revisions}" ]]; then
        echo "No old revisions found."
        exit
    fi
    versioncount=$(echo "${old_revisions}" | wc -l)

    if [[ "${versioncount}" -le "${KEEP_COUNT}" ]]; then
        echo "Too few revisions of ${SERVICE_NAME} available. Not removing any."
    else
        TO_REMOVE=$(echo "${old_revisions}" | head -"$(( versioncount - KEEP_COUNT ))")
        echo "$(echo "${TO_REMOVE}" | wc -l) objects to remove: $(echo "${TO_REMOVE}" | xargs)"
        # TO_REMOVE is guaranteed non-empty here (versioncount > KEEP_COUNT),
        # so the original's redundant -n guard is dropped.
        for revision in $(echo "${TO_REMOVE}" | xargs); do
            # FIX: use the documented flag placement — flags after the
            # `revisions delete` command, not between group and command.
            gcloud --quiet --project "${PROJECT}" run revisions delete "${revision}" --region "${REGION}"
        done
    fi
fi

Client Access

One last little note: Because we’re using IAP, we need to ensure our experiments can actually connect to MLFlow with proper service account auth. A small script to ensure clients can reach MLFlow:

def enable_mlflow_login(
    sa: str = "<service account email address>",
    client_id: str = "<oauth client id>",
    tracking_server: str = "https://mlflow.example.com",
) -> None:
    """
    Point the MLFlow client at an IAP-protected tracking server.

    Mints an OIDC identity token for the service account *sa* with the
    IAP OAuth client id as audience, then exports it (plus the tracking
    URI) through the environment variables the MLFlow client reads.
    """
    # Deferred imports: only pay the load cost when login is requested.
    import os

    from google.cloud import iam_credentials

    # Ask the IAM Credentials API for an identity token on behalf of sa.
    iam_client = iam_credentials.IAMCredentialsClient()
    id_token = iam_client.generate_id_token(
        name=f"projects/-/serviceAccounts/{sa}",
        audience=client_id,
        include_email=True,
    ).token
    # MLFlow picks these up automatically on the next tracking call.
    os.environ["MLFLOW_TRACKING_TOKEN"] = id_token
    os.environ["MLFLOW_TRACKING_URI"] = tracking_server