Skip to content

MLflow

Run the MLflow UI (it serves on http://127.0.0.1:5000 by default):

uv run --with mlflow==3.7.0 mlflow ui --backend-store-uri sqlite:///mlflow.db

Install packages

1
2
3
4
5
6
7
8
9
!uv pip install -q \
    mlflow==3.7.0 \
    python-dotenv==1.2.1 \
    pandas==2.3.2 \
    pandas-stubs==2.3.2.250827 \
    numpy==2.3.2 \
    matplotlib==3.10.6 \
    seaborn==0.13.2 \
    scikit-learn==1.7.1

Import packages

1
2
3
4
5
6
7
8
import mlflow
import mlflow.sklearn
import pandas as pd
from mlflow.models import infer_signature
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

Set tracking URI

# Send every run to the local tracking server; the run and experiment links
# printed further down use this same address.
mlflow.set_tracking_uri("http://127.0.0.1:5000")

Load dataset

# return_X_y=True unpacks directly into the feature matrix X and the target vector y.
X, y = datasets.load_iris(return_X_y=True)
# Trailing expression: the notebook displays the first five feature rows.
X[:5]
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2]])

Split dataset

# Hold out 20% of the samples for evaluation.
# NOTE(review): no random_state is passed, so the split, and every prediction
# and metric derived from it, changes from run to run.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

Define model hyperparams

1
2
3
4
5
6
# Keyword arguments for LogisticRegression; the same mapping is logged to
# MLflow later via mlflow.log_params(params).
params = {
    # "penalty" and "solver" do not appear in the printed estimator repr below,
    # which suggests they equal the library defaults -- TODO confirm.
    "penalty": "l2",
    "solver": "lbfgs",
    "max_iter": 1000,
    # Seed for the estimator only; it does not seed the train/test split.
    "random_state": 8888,
}

Train model

1
2
3
# Build the classifier from the hyperparameter mapping and fit it in one
# expression; fit() returns the estimator itself, so `model` is the fitted object.
model = LogisticRegression(**params).fit(X_train, y_train)
print(model)
LogisticRegression(max_iter=1000, random_state=8888)

Predict

# Predicted class labels for the held-out test split.
y_pred = model.predict(X_test)
# Trailing expression: the notebook displays the predicted label array.
y_pred
array([2, 1, 2, 0, 2, 0, 1, 1, 0, 0, 0, 0, 1, 2, 0, 2, 1, 1, 0, 2, 1, 2,
       1, 0, 0, 0, 0, 2, 2, 2])

Accuracy evaluation

# Compare the test-split predictions against the true labels; this value is
# logged to MLflow as the "accuracy" metric further down.
accuracy = accuracy_score(y_test, y_pred)
# Trailing expression: the notebook displays the score.
accuracy
0.9333333333333333

Create an experiment

# Select the experiment that subsequent runs are logged under; as the log line
# below shows, it is created when it does not exist yet.
mlflow.set_experiment("simple_regression")
2025/12/16 13:43:59 INFO mlflow.tracking.fluent: Experiment with name 'simple_regression' does not exist. Creating a new experiment.
<Experiment: artifact_location='mlflow-artifacts:/1', creation_time=1765903439509, experiment_id='1', last_update_time=1765903439509, lifecycle_stage='active', name='simple_regression', tags={}>
# Record the hyperparameters, the test metric and the fitted model in one run.
with mlflow.start_run():
    # Hyperparameters the estimator was constructed with.
    mlflow.log_params(params)

    # Accuracy measured on the held-out test split.
    mlflow.log_metric("accuracy", accuracy)

    # Free-form note attached to the run.
    mlflow.set_tag("Training Info", "Basic LR model for iris data")

    # Input/output schema inferred from the training features and the model's
    # predictions on them.
    signature = infer_signature(X_train, model.predict(X_train))

    # Log the model and register it as a new version in a single call.
    # Only a handful of rows are passed as the input example: it is stored as
    # an artifact alongside the model, so passing the whole training matrix
    # would persist the full training set with every logged model.
    model_info = mlflow.sklearn.log_model(
        sk_model=model,
        name="iris_model",
        signature=signature,
        input_example=X_train[:5],
        registered_model_name="tracking-simple-regression",
    )
2025/12/16 15:04:03 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.

Registered model 'tracking-simple-regression' already exists. Creating a new version of this model...

2025/12/16 15:04:03 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: tracking-simple-regression, version 2

Created version '2' of model 'tracking-simple-regression'.
๐Ÿƒ View run hilarious-bug-698 at: http://127.0.0.1:5000/#/experiments/1/runs/c435fe117ffa4e649d6bf5c0c8f0dd31

๐Ÿงช View experiment at: http://127.0.0.1:5000/#/experiments/1

Load the model

# Reload the logged model through the generic pyfunc interface and score the
# held-out test split with it.
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)
predictions = loaded_model.predict(X_test)

# Put the features, the true labels and the reloaded model's predictions side
# by side for inspection.
iris_features_names = datasets.load_iris().feature_names
result = pd.DataFrame(X_test, columns=iris_features_names).assign(
    actual_class=y_test,
    predicted_class=predictions,
)

result.head()
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) actual_class predicted_class
0 6.4 2.8 5.6 2.2 2 2
1 5.6 2.9 3.6 1.3 1 1
2 6.1 3.0 4.9 1.8 2 2
3 4.6 3.2 1.4 0.2 0 0
4 5.8 2.8 5.1 2.4 2 2

Model Registry

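The Model Registry is a centralized model store for collaboratively managing the full lifecycle of MLflow Models.

It provides model versioning, stage transitions (e.g., Staging to Production; note that stages are deprecated in recent MLflow releases in favor of model version aliases and tags), and annotations.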

1
2
3
4
5
6
7
# Name under which the model was registered when it was logged above.
model_name = "tracking-simple-regression"
# NOTE(review): "latest" is expected to resolve to the newest registered
# version -- confirm against the MLflow model registry URI documentation.
model_version = "latest"

model_uri = f"models:/{model_name}/{model_version}"

# Load via the sklearn flavor (the printed repr below is a LogisticRegression).
# This rebinds the notebook-level name `model` used for the trained estimator above.
model = mlflow.sklearn.load_model(model_uri)
print(model)
LogisticRegression(max_iter=1000, random_state=8888)