Skip to content

Churn Prediction

Run the MLflow UI:

uv run --with mlflow==3.7.0 mlflow ui --backend-store-uri sqlite:///mlflow.db

Install packages

!uv pip install -q \
    requests==2.32.5

Append notebooks directory to sys.path

1
2
3
import sys

sys.path.append("../../../..")

Import packages

import requests

Training pipeline

app/train.py
# /// script
# requires-python = ">=3.11,<3.13"
# dependencies = [
#     "python-dotenv==1.2.1",
#     "pandas==2.3.2",
#     "pandas-stubs==2.3.2.250827",
#     "numpy==2.3.2",
#     "scikit-learn==1.7.1",
# ]
# ///

import sys

sys.path.append("../../../../..")

import os
import pathlib
import pickle

import pandas as pd
from dotenv import load_dotenv
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from notebooks.python.utils.data_extraction.data_extraction import (
    KaggleDataExtractor,
    KaggleExtractionConfig,
)

# Show every DataFrame column when printing (the churn dataset is wide).
pd.set_option("display.max_columns", None)

load_dotenv()  # Load Kaggle credentials from the root .env file

BASE_PATH = pathlib.Path("../../../machine-learning")
DATA_DIR = BASE_PATH / "data/predicting-customer-churn"
OUTPUT_DIR = BASE_PATH / "artifacts/predicting-customer-churn"

# parents=True creates the full directory tree on a fresh checkout;
# a plain mkdir(exist_ok=True) raises FileNotFoundError if any parent
# directory is missing.
DATA_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Kaggle API credentials; may be None when the .env file is not configured —
# only needed on the first run, when the dataset is downloaded.
username = os.getenv("KAGGLE_USERNAME")
api_token = os.getenv("KAGGLE_API_TOKEN")
file_name = "WA_Fn-UseC_-Telco-Customer-Churn.csv"


def load_data():
    """Fetch the Telco churn CSV (downloading from Kaggle on first run),
    normalise its columns/values, and return the cleaned DataFrame."""
    extractor = KaggleDataExtractor(username=username, api_token=api_token)

    config = KaggleExtractionConfig(
        dataset_slug="blastchar/telco-customer-churn",
        file_name=file_name,
        destination_path=DATA_DIR,
        output_file_name="churn.csv",
    )

    churn_path = DATA_DIR / "churn.csv"

    # Download only once; subsequent runs reuse the cached CSV.
    if not churn_path.is_file():
        extractor.download_dataset(config)

    df = pd.read_csv(churn_path)

    # Normalise column names and all string values:
    # lowercase, spaces replaced with underscores.
    df.columns = df.columns.str.lower().str.replace(" ", "_")

    for column in df.select_dtypes(include="object").columns:
        df[column] = df[column].str.lower().str.replace(" ", "_")

    # totalcharges contains blanks for brand-new customers; coerce those
    # to NaN and then fill with 0.
    df["totalcharges"] = pd.to_numeric(df["totalcharges"], errors="coerce")
    df["totalcharges"] = df["totalcharges"].fillna(0)

    # Binary target: 1 if the customer churned, else 0.
    df["churn"] = (df["churn"] == "yes").astype(int)

    return df


def train_model(df):
    """Fit a DictVectorizer + LogisticRegression pipeline on the cleaned
    churn DataFrame and return the fitted pipeline."""
    numerical = ["tenure", "monthlycharges", "totalcharges"]

    categorical = [
        "gender",
        "seniorcitizen",
        "partner",
        "dependents",
        "phoneservice",
        "multiplelines",
        "internetservice",
        "onlinesecurity",
        "onlinebackup",
        "deviceprotection",
        "techsupport",
        "streamingtv",
        "streamingmovies",
        "contract",
        "paperlessbilling",
        "paymentmethod",
    ]

    # DictVectorizer consumes records (one dict per row) and one-hot
    # encodes the categorical entries while passing numerics through.
    feature_records = df[categorical + numerical].to_dict(orient="records")
    target = df.churn

    model = make_pipeline(
        DictVectorizer(), LogisticRegression(solver="liblinear")
    )
    model.fit(feature_records, target)

    return model


def save_model(pipeline, output_file):
    """Serialise the fitted pipeline to *output_file* using pickle."""
    payload = pickle.dumps(pipeline)
    with open(output_file, "wb") as f_out:
        f_out.write(payload)


# Script entry point: extract and clean the data, fit the pipeline,
# and persist the trained model next to this script as model.bin.
df = load_data()
pipeline = train_model(df)
save_model(pipeline, "model.bin")

Run pipeline

!cd app && \
    uv run train.py
2026/01/16 18:22:31 INFO mlflow.tracking.fluent: Experiment with name 'telco-churn-prediction' does not exist. Creating a new experiment.

2026/01/16 18:22:45 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.

🏃 View run enthused-wasp-823 at: http://127.0.0.1:5000/#/experiments/1/runs/5385b8f560694da2b6db3b04a1bff394

🧪 View experiment at: http://127.0.0.1:5000/#/experiments/1

Application code

app/main.py
import pickle
from typing import Literal

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel, Field


class Customer(BaseModel):
    """Request schema for /predict.

    Literal values mirror the cleaned training data: the training
    pipeline lowercases string values and replaces spaces with
    underscores (e.g. "Fiber optic" -> "fiber_optic"), so requests
    must use the same normalised form.
    """

    gender: Literal["male", "female"]
    seniorcitizen: Literal[0, 1]
    partner: Literal["yes", "no"]
    dependents: Literal["yes", "no"]
    phoneservice: Literal["yes", "no"]
    multiplelines: Literal["no", "yes", "no_phone_service"]
    internetservice: Literal["dsl", "fiber_optic", "no"]
    onlinesecurity: Literal["no", "yes", "no_internet_service"]
    onlinebackup: Literal["no", "yes", "no_internet_service"]
    deviceprotection: Literal["no", "yes", "no_internet_service"]
    techsupport: Literal["no", "yes", "no_internet_service"]
    streamingtv: Literal["no", "yes", "no_internet_service"]
    streamingmovies: Literal["no", "yes", "no_internet_service"]
    contract: Literal["month-to-month", "one_year", "two_year"]
    paperlessbilling: Literal["yes", "no"]
    paymentmethod: Literal[
        "electronic_check",
        "mailed_check",
        "bank_transfer_(automatic)",
        "credit_card_(automatic)",
    ]
    # Numeric features; non-negative per the Field constraints below.
    tenure: int = Field(..., ge=0)
    monthlycharges: float = Field(..., ge=0.0)
    totalcharges: float = Field(..., ge=0.0)


class PredictResponse(BaseModel):
    """Response schema: predicted churn probability plus the thresholded flag."""

    # Positive-class probability from the classifier, in [0, 1].
    churn_probability: float
    # True when churn_probability >= 0.5.
    churn: bool


app = FastAPI(title="customer-churn-prediction")

# Load the pickled scikit-learn pipeline once at import time; requires
# model.bin (produced by train.py) in the working directory.
with open("model.bin", "rb") as f_in:
    pipeline = pickle.load(f_in)


def predict_single(customer):
    """Return the churn probability for one customer feature dict."""
    # predict_proba yields one row with [P(no churn), P(churn)];
    # take the positive-class column.
    probabilities = pipeline.predict_proba(customer)
    return float(probabilities[0, 1])


@app.post("/predict")
def predict(customer: Customer) -> PredictResponse:
    """Score one customer and flag churn at the 0.5 probability threshold."""
    probability = predict_single(customer.model_dump())

    return PredictResponse(
        churn_probability=probability,
        churn=probability >= 0.5,
    )


if __name__ == "__main__":
    # Local development entry point; in Docker the server is started
    # via the image's ENTRYPOINT instead.
    uvicorn.run(app, host="0.0.0.0", port=9696)

Dockerfile

app/Dockerfile
FROM python:3.11-slim-bookworm

# Pull the uv/uvx binaries from the official astral-sh image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
WORKDIR /code

# Put the project virtualenv first on PATH so uvicorn resolves to it.
ENV PATH="/code/.venv/bin:$PATH"

# Copy dependency manifests first so the sync layer stays cached
# until pyproject.toml / uv.lock / .python-version change.
COPY "pyproject.toml" "uv.lock" ".python-version" ./
RUN uv sync --locked

# Application code and the trained model artifact.
COPY "main.py" "model.bin" ./

EXPOSE 9696

ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "9696"]

Dependencies

app/pyproject.toml
# Project metadata and runtime dependencies for the churn-prediction service.
[project]
name = "app"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.128.0",    # web framework (provides pydantic transitively)
    "scikit-learn==1.7.1", # pinned to the training version so model.bin unpickles
    "uvicorn>=0.40.0",     # ASGI server
]
!aws --endpoint-url=http://localhost:4566 s3 ls s3://mlflow-artifacts/ --recursive
2026-01-16 18:22:46        652 1/models/m-eeba80cd1ff2478b97fbc45fe5bd1294/artifacts/MLmodel

2026-01-16 18:22:46        222 1/models/m-eeba80cd1ff2478b97fbc45fe5bd1294/artifacts/conda.yaml

2026-01-16 18:22:46       2601 1/models/m-eeba80cd1ff2478b97fbc45fe5bd1294/artifacts/model.pkl

2026-01-16 18:22:46         98 1/models/m-eeba80cd1ff2478b97fbc45fe5bd1294/artifacts/python_env.yaml

2026-01-16 18:22:46        109 1/models/m-eeba80cd1ff2478b97fbc45fe5bd1294/artifacts/requirements.txt

Build docker image

!cd app && \
    docker build -q -t fastapi-predict-churn-mlflow .
sha256:8efad11ce0ac3b0a66539178811863885cc9320d29b350ae14ef99e322b5c795

Run application

!docker run --rm -it \
  --network datadev_datadev-net \
  -d \
  -p 9696:9696 \
  -e MLFLOW_TRACKING_URI=http://mlflow:5000 \
  -e MLFLOW_S3_ENDPOINT_URL=http://localstack:4566 \
  -e AWS_ACCESS_KEY_ID=test \
  -e AWS_SECRET_ACCESS_KEY=test \
  -e AWS_DEFAULT_REGION=us-east-1 \
  -e AWS_EC2_METADATA_DISABLED=true \
  -e RUN_ID=eeba80cd1ff2478b97fbc45fe5bd1294 \
  fastapi-predict-churn-mlflow

Check running applications

!docker ps
CONTAINER ID   IMAGE                          COMMAND                  CREATED          STATUS                      PORTS                                                                  NAMES

917d7dd1e9ef   fastapi-predict-churn-mlflow   "uvicorn main:app --…"   34 seconds ago   Up 33 seconds               0.0.0.0:9696->9696/tcp, [::]:9696->9696/tcp                            practical_thompson

a9bb4ca672cc   datadev-mlflow                 "mlflow server --hos…"   18 minutes ago   Up 18 minutes               0.0.0.0:5000->5000/tcp, [::]:5000->5000/tcp                            mlflow

78fca0cbd827   localstack/localstack:latest   "docker-entrypoint.sh"   18 minutes ago   Up 18 minutes (unhealthy)   4510-4559/tcp, 5678/tcp, 0.0.0.0:4566->4566/tcp, [::]:4566->4566/tcp   localstack

7c3f39be4336   postgres:18-alpine3.22         "docker-entrypoint.s…"   18 minutes ago   Up 18 minutes               0.0.0.0:5432->5432/tcp, [::]:5432->5432/tcp                            mlflow-postgres

Request application

url = "http://localhost:9696/predict"

customer = {
    "gender": "female",
    "seniorcitizen": 0,
    "partner": "yes",
    "dependents": "no",
    "phoneservice": "no",
    "multiplelines": "no_phone_service",
    "internetservice": "dsl",
    "onlinesecurity": "no",
    "onlinebackup": "yes",
    "deviceprotection": "no",
    "techsupport": "no",
    "streamingtv": "no",
    "streamingmovies": "no",
    "contract": "month-to-month",
    "paperlessbilling": "yes",
    "paymentmethod": "electronic_check",
    "tenure": 1,
    "monthlycharges": 29.85,
    "totalcharges": 29.85,
}

response = requests.post(url, json=customer)

predictions = response.json()

print(predictions)
{'churn_probability': 0.6638167617162171, 'churn': True}

Stop container

!docker stop 917d7dd1e9ef
917d7dd1e9ef

Delete container image

!docker rmi fastapi-predict-churn-mlflow
Untagged: fastapi-predict-churn-mlflow:latest

Deleted: sha256:8efad11ce0ac3b0a66539178811863885cc9320d29b350ae14ef99e322b5c795