API Reference

Tracking Store

DynamoDBTrackingStore

DynamoDBTrackingStore(store_uri, artifact_uri=None)

Bases: AbstractStore

MLflow tracking store backed by DynamoDB.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def __init__(
    self,
    store_uri: str,
    artifact_uri: str | None = None,
) -> None:
    uri = parse_dynamodb_uri(store_uri)
    if uri.deploy:
        ensure_stack_exists(uri.table_name, uri.region, uri.endpoint_url)
    self._table = DynamoDBTable(uri.table_name, uri.region, uri.endpoint_url)
    self._uri = uri
    self._cache = ResolutionCache(workspace=lambda: self._workspace)
    self._artifact_uri = artifact_uri or "./mlartifacts"
    self.artifact_root_uri = self._artifact_uri
    self._config = ConfigReader(self._table)
    self._config.reconcile()
    super().__init__()
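A minimal construction sketch. The URI form here is an assumption (whatever scheme and query parameters parse_dynamodb_uri accepts in your installed version), and in normal use MLflow resolves the store through its tracking-store registry rather than by direct construction:

from mlflow_dynamodbstore.tracking_store import DynamoDBTrackingStore

# Hypothetical URI: table "mlflow-tracking" in us-east-1.
store = DynamoDBTrackingStore(
    store_uri="dynamodb://mlflow-tracking?region=us-east-1",
    artifact_uri="s3://my-bucket/mlartifacts",  # hypothetical bucket
)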

supports_workspaces property

supports_workspaces

DynamoDB store always supports workspaces.

Workspace scoping is built into the schema (GSI2/GSI3 prefixes, META workspace attribute). The --enable-workspaces server flag controls whether workspace features are active at runtime.

create_experiment

create_experiment(name, artifact_location=None, tags=None)

Create a new experiment and return its ID (ULID).

Source code in src/mlflow_dynamodbstore/tracking_store.py
def create_experiment(
    self,
    name: str,
    artifact_location: str | None = None,
    tags: list[ExperimentTag] | None = None,
) -> str:
    """Create a new experiment and return its ID (ULID)."""
    # Check uniqueness via GSI3
    existing = self._table.query(
        pk=f"{GSI3_EXP_NAME_PREFIX}{self._workspace}#{name}",
        index_name="gsi3",
        limit=1,
    )
    if existing:
        raise MlflowException(
            f"Experiment(name={name}) already exists.",
            error_code=RESOURCE_ALREADY_EXISTS,
        )

    now_ms = int(time.time() * 1000)
    exp_id = generate_ulid()

    item: dict[str, Any] = {
        "PK": f"{PK_EXPERIMENT_PREFIX}{exp_id}",
        "SK": SK_EXPERIMENT_META,
        "name": name,
        "lifecycle_stage": "active",
        "artifact_location": artifact_location or self._artifact_uri,
        "creation_time": now_ms,
        "last_update_time": now_ms,
        "workspace": self._workspace,
        "tags": {},
        # LSI attributes
        LSI1_SK: f"active#{exp_id}",
        LSI2_SK: now_ms,
        LSI3_SK: name,
        LSI4_SK: _rev(name),
        # GSI2: list experiments by lifecycle
        GSI2_PK: f"{GSI2_EXPERIMENTS_PREFIX}{self._workspace}#active",
        GSI2_SK: exp_id,
        # GSI3: unique name lookup
        GSI3_PK: f"{GSI3_EXP_NAME_PREFIX}{self._workspace}#{name}",
        GSI3_SK: exp_id,
        # GSI5: all experiment names
        GSI5_PK: f"{GSI5_EXP_NAMES_PREFIX}{self._workspace}",
        GSI5_SK: f"{name}#{exp_id}",
    }

    self._table.put_item(item, condition="attribute_not_exists(PK)")

    # Write NAME_REV item for suffix ILIKE support (GSI5)
    name_rev_item = {
        "PK": f"{PK_EXPERIMENT_PREFIX}{exp_id}",
        "SK": SK_EXPERIMENT_NAME_REV,
        GSI5_PK: f"{GSI5_EXP_NAMES_PREFIX}{self._workspace}",
        GSI5_SK: f"REV#{_rev(name.lower())}#{exp_id}",
        "name": name,
    }
    self._table.put_item(name_rev_item)

    # Write tags if provided
    if tags:
        for tag in tags:
            self._write_experiment_tag(exp_id, tag)

    # Write FTS items for experiment name
    fts_items = fts_items_for_text(
        pk=f"{PK_EXPERIMENT_PREFIX}{exp_id}",
        entity_type="E",
        entity_id=exp_id,
        field=None,
        text=name,
        workspace=self._workspace,
    )
    if fts_items:
        self._table.batch_write(fts_items)

    return exp_id
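A usage sketch (the experiment name and tag are hypothetical; store is a constructed DynamoDBTrackingStore):

from mlflow.entities import ExperimentTag

exp_id = store.create_experiment(
    name="churn-model",
    tags=[ExperimentTag("team", "ml-infra")],
)
# Returns a ULID string; name uniqueness is enforced through GSI3, so a
# second call with the same name raises RESOURCE_ALREADY_EXISTS.
print(exp_id)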

get_experiment

get_experiment(experiment_id)

Fetch an experiment by ID.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_experiment(self, experiment_id: str) -> Experiment:
    """Fetch an experiment by ID."""
    item = self._table.get_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=SK_EXPERIMENT_META,
    )
    if item is None:
        raise MlflowException(
            f"No Experiment with id={experiment_id} exists",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    item_workspace = item.get("workspace", "default")
    if item_workspace != self._workspace:
        raise MlflowException(
            f"No Experiment with id={experiment_id} exists",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    tags = self._get_experiment_tags(experiment_id)
    return _item_to_experiment(item, tags)

get_experiment_by_name

get_experiment_by_name(experiment_name)

Fetch an experiment by name, or None if not found.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_experiment_by_name(self, experiment_name: str) -> Experiment | None:
    """Fetch an experiment by name, or None if not found."""
    results = self._table.query(
        pk=f"{GSI3_EXP_NAME_PREFIX}{self._workspace}#{experiment_name}",
        index_name="gsi3",
        limit=1,
    )
    if not results:
        return None

    exp_id = results[0]["PK"].replace(PK_EXPERIMENT_PREFIX, "")
    return self.get_experiment(exp_id)

rename_experiment

rename_experiment(experiment_id, new_name)

Rename an experiment.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def rename_experiment(self, experiment_id: str, new_name: str) -> None:
    """Rename an experiment."""
    exp = self.get_experiment(experiment_id)
    old_name = exp.name

    # Check new name uniqueness
    existing = self._table.query(
        pk=f"{GSI3_EXP_NAME_PREFIX}{self._workspace}#{new_name}",
        index_name="gsi3",
        limit=1,
    )
    if existing:
        raise MlflowException(
            f"Experiment(name={new_name}) already exists.",
            error_code=RESOURCE_ALREADY_EXISTS,
        )

    now_ms = int(time.time() * 1000)

    self._table.update_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=SK_EXPERIMENT_META,
        updates={
            "name": new_name,
            "last_update_time": now_ms,
            LSI2_SK: now_ms,
            LSI3_SK: new_name,
            LSI4_SK: _rev(new_name),
            GSI3_PK: f"{GSI3_EXP_NAME_PREFIX}{self._workspace}#{new_name}",
            GSI5_SK: f"{new_name}#{experiment_id}",
        },
    )

    # Update NAME_REV item with new reversed name
    self._table.update_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=SK_EXPERIMENT_NAME_REV,
        updates={
            GSI5_SK: f"REV#{_rev(new_name.lower())}#{experiment_id}",
            "name": new_name,
        },
    )

    # Update FTS items for experiment name change
    levels = ("W", "3", "2")  # always trigram for experiment_name
    tokens_to_add, tokens_to_remove = fts_diff(old_name, new_name, levels)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    # Delete removed FTS items by looking up via reverse SK prefix
    if tokens_to_remove:
        rev_prefix = f"{SK_FTS_REV_PREFIX}E#{experiment_id}#"
        rev_items = self._table.query(pk=pk, sk_prefix=rev_prefix)
        for rev_item in rev_items:
            # Parse level and token from reverse SK: FTS_REV#E#<id>#<level>#<token>
            rev_sk = rev_item["SK"]
            # Build the corresponding forward SK to delete it too
            # rev_sk pattern: FTS_REV#E#<entity_id>#<level>#<token>
            parts = rev_sk[len(SK_FTS_REV_PREFIX) :].split("#")
            # parts = ["E", experiment_id, level, token]
            if len(parts) >= 4:
                lvl, tok = parts[2], parts[3]
                if (lvl, tok) in tokens_to_remove:
                    forward_sk = f"{SK_FTS_PREFIX}{lvl}#E#{tok}#{experiment_id}"
                    self._table.delete_item(pk=pk, sk=forward_sk)
                    self._table.delete_item(pk=pk, sk=rev_sk)

    # Write new FTS items for added tokens
    if tokens_to_add:
        new_fts_items: list[dict[str, Any]] = []
        for lvl, tok in tokens_to_add:
            forward_sk = f"{SK_FTS_PREFIX}{lvl}#E#{tok}#{experiment_id}"
            reverse_sk = f"{SK_FTS_REV_PREFIX}E#{experiment_id}#{lvl}#{tok}"
            gsi2pk_val = f"{GSI2_FTS_NAMES_PREFIX}{self._workspace}"
            gsi2sk_val = f"{lvl}#E#{tok}#{experiment_id}"
            new_fts_items.append(
                {"PK": pk, "SK": forward_sk, GSI2_PK: gsi2pk_val, GSI2_SK: gsi2sk_val}
            )
            new_fts_items.append({"PK": pk, "SK": reverse_sk})
        self._table.batch_write(new_fts_items)

    # Invalidate cache
    self._cache.invalidate("exp_name", old_name)

delete_experiment

delete_experiment(experiment_id)

Soft-delete an experiment and set TTL on META only.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_experiment(self, experiment_id: str) -> None:
    """Soft-delete an experiment and set TTL on META only."""
    self._check_experiment_workspace(experiment_id)
    now_ms = int(time.time() * 1000)

    updates: dict[str, Any] = {
        "lifecycle_stage": "deleted",
        "last_update_time": now_ms,
        LSI1_SK: f"deleted#{experiment_id}",
        LSI2_SK: now_ms,
        GSI2_PK: f"{GSI2_EXPERIMENTS_PREFIX}{self._workspace}#deleted",
    }

    # Compute TTL if enabled
    ttl_seconds = self._config.get_soft_deleted_ttl_seconds()
    if ttl_seconds is not None:
        updates["ttl"] = int(time.time()) + ttl_seconds

    self._table.update_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=SK_EXPERIMENT_META,
        updates=updates,
    )

restore_experiment

restore_experiment(experiment_id)

Restore a soft-deleted experiment and remove TTL from META.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def restore_experiment(self, experiment_id: str) -> None:
    """Restore a soft-deleted experiment and remove TTL from META."""
    self._check_experiment_workspace(experiment_id)
    now_ms = int(time.time() * 1000)

    self._table.update_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=SK_EXPERIMENT_META,
        updates={
            "lifecycle_stage": "active",
            "last_update_time": now_ms,
            LSI1_SK: f"active#{experiment_id}",
            LSI2_SK: now_ms,
            GSI2_PK: f"{GSI2_EXPERIMENTS_PREFIX}{self._workspace}#active",
        },
        removes=["ttl"],
    )
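The two lifecycle methods pair up. A sketch continuing the example above:

store.delete_experiment(exp_id)
assert store.get_experiment(exp_id).lifecycle_stage == "deleted"

# restore_experiment flips the lifecycle back and removes the `ttl`
# attribute that delete_experiment may have stamped on the META item.
store.restore_experiment(exp_id)
assert store.get_experiment(exp_id).lifecycle_stage == "active"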

search_experiments

search_experiments(view_type=ACTIVE_ONLY, max_results=1000, filter_string=None, order_by=None, page_token=None)

Search experiments with filter and order_by support.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def search_experiments(
    self,
    view_type: int = ViewType.ACTIVE_ONLY,
    max_results: int = 1000,
    filter_string: str | None = None,
    order_by: list[str] | None = None,
    page_token: str | None = None,
) -> list[Experiment]:
    """Search experiments with filter and order_by support."""
    from mlflow_dynamodbstore.dynamodb.search import parse_experiment_filter

    predicates = parse_experiment_filter(filter_string)

    # Classify predicates
    name_pred = next(
        (p for p in predicates if p.field_type == "attribute" and p.key == "name"),
        None,
    )
    tag_preds = [p for p in predicates if p.field_type == "tag"]

    if name_pred and name_pred.op == "=":
        experiments = self._search_experiments_by_name_exact(name_pred.value, view_type)
    elif name_pred and name_pred.op in ("LIKE", "ILIKE"):
        experiments = self._search_experiments_by_name_like(
            name_pred.value, name_pred.op, view_type
        )
    else:
        experiments = self._search_experiments_by_lifecycle(view_type, max_results)

    # Apply tag filters as post-filters
    if tag_preds:
        experiments = self._filter_experiments_by_tags(experiments, tag_preds)

    # Apply ordering
    if order_by:
        experiments = self._sort_experiments(experiments, order_by)

    from mlflow.store.entities import PagedList

    results = experiments[:max_results] if max_results else experiments
    return PagedList(results, token=None)
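A filter sketch using MLflow's standard experiment-search grammar (the filter value is hypothetical). Exact name = predicates hit GSI3 directly, LIKE/ILIKE take the name-index paths, and tag predicates are applied as post-filters:

from mlflow.entities import ViewType

page = store.search_experiments(
    view_type=ViewType.ACTIVE_ONLY,
    filter_string="name ILIKE '%churn%'",
    order_by=["creation_time DESC"],
    max_results=50,
)
for exp in page:
    print(exp.experiment_id, exp.name)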

set_experiment_tag

set_experiment_tag(experiment_id, tag)

Set a tag on an experiment.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def set_experiment_tag(self, experiment_id: str, tag: ExperimentTag) -> None:
    """Set a tag on an experiment."""
    # Verify experiment exists
    self.get_experiment(experiment_id)
    self._write_experiment_tag(experiment_id, tag)

delete_experiment_tag

delete_experiment_tag(experiment_id, key)

Delete a tag from an experiment.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_experiment_tag(self, experiment_id: str, key: str) -> None:
    """Delete a tag from an experiment."""
    self.get_experiment(experiment_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_EXPERIMENT_TAG_PREFIX}{key}"
    self._table.delete_item(pk=pk, sk=sk)
    if self._config.should_denormalize(None, key):
        self._remove_denormalized_tag(pk, SK_EXPERIMENT_META, key)

create_run

create_run(experiment_id, user_id, start_time, tags, run_name)

Create a new run within an experiment.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def create_run(
    self,
    experiment_id: str,
    user_id: str,
    start_time: int,
    tags: list[RunTag],
    run_name: str,
) -> Run:
    """Create a new run within an experiment."""
    # Verify experiment exists
    exp = self.get_experiment(experiment_id)

    run_id = ulid_from_timestamp(start_time)
    artifact_uri = f"{exp.artifact_location}/{run_id}/artifacts"

    # Resolve run_name: use tag value, generate random name, or keep provided
    from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
    from mlflow.utils.name_utils import _generate_random_name

    run_name_tag = next((t for t in (tags or []) if t.key == MLFLOW_RUN_NAME), None)
    if run_name and run_name_tag and run_name != run_name_tag.value:
        raise MlflowException(
            f"Both 'run_name' argument and 'mlflow.runName' tag are specified, but with "
            f"different values (run_name='{run_name}', "
            f"mlflow.runName='{run_name_tag.value}').",
            error_code=INVALID_PARAMETER_VALUE,
        )
    run_name = (
        run_name or (run_name_tag.value if run_name_tag else None) or _generate_random_name()
    )

    item: dict[str, Any] = {
        "PK": f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        "SK": f"{SK_RUN_PREFIX}{run_id}",
        "run_id": run_id,
        "experiment_id": experiment_id,
        "user_id": user_id,
        "status": "RUNNING",
        "start_time": start_time,
        "run_name": run_name,
        "lifecycle_stage": "active",
        "artifact_uri": artifact_uri,
        "tags": {},
        # LSI attributes
        LSI1_SK: f"active#{run_id}",
        LSI3_SK: f"RUNNING#{run_id}",
        # GSI1: reverse lookup run_id -> experiment_id
        GSI1_PK: f"{GSI1_RUN_PREFIX}{run_id}",
        GSI1_SK: f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
    }

    # LSI4 is sparse — only set when run_name is non-empty (DynamoDB rejects empty string keys)
    if run_name:
        item[LSI4_SK] = run_name.lower()

    self._table.put_item(item, condition="attribute_not_exists(PK)")

    # Build the full tag list for the returned Run entity
    all_tags: list[RunTag] = list(tags or [])

    # Write mlflow.runName system tag
    if run_name and not run_name_tag:
        run_name_sys_tag = RunTag(MLFLOW_RUN_NAME, run_name)
        self._write_run_tag(experiment_id, run_id, run_name_sys_tag)
        all_tags.append(run_name_sys_tag)

    # Write tags if provided
    if tags:
        for tag in tags:
            self._write_run_tag(experiment_id, run_id, tag)

    # Cache run_id -> experiment_id
    self._cache.put("run_exp", run_id, experiment_id)

    # Write FTS items for run name
    if run_name:
        run_fts = fts_items_for_text(
            pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
            entity_type="R",
            entity_id=run_id,
            field=None,
            text=run_name,
        )
        if run_fts:
            self._table.batch_write(run_fts)

    return self._build_run(item, all_tags)
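A usage sketch (user and run name are hypothetical). Passing both run_name and the mlflow.runName tag is only valid when they agree; omitting both yields a generated random name:

import time

from mlflow.entities import RunTag
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME

run = store.create_run(
    experiment_id=exp_id,
    user_id="alice",
    start_time=int(time.time() * 1000),
    tags=[RunTag(MLFLOW_RUN_NAME, "baseline")],
    run_name="baseline",  # must equal the tag value when both are given
)
print(run.info.run_id, run.info.run_name)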

get_run

get_run(run_id)

Fetch a run by ID.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_run(self, run_id: str) -> Run:
    """Fetch a run by ID."""
    experiment_id = self._resolve_run_experiment(run_id)

    item = self._table.get_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=f"{SK_RUN_PREFIX}{run_id}",
    )
    if item is None:
        raise MlflowException(
            f"Run '{run_id}' does not exist.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Query tags, params, metrics for this run
    run_prefix = f"{SK_RUN_PREFIX}{run_id}"
    tag_items = self._table.query(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk_prefix=f"{run_prefix}{SK_TAG_PREFIX}",
    )
    param_items = self._table.query(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk_prefix=f"{run_prefix}{SK_PARAM_PREFIX}",
    )
    metric_items = self._table.query(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk_prefix=f"{run_prefix}{SK_METRIC_PREFIX}",
    )

    # Query input items (INPUT links and their ITAG children share the same prefix)
    all_input_items = self._table.query(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk_prefix=f"{run_prefix}{SK_INPUT_PREFIX}",
    )
    # Separate INPUT link items from ITAG items
    input_items = []
    input_tag_items = []
    for it in all_input_items:
        if SK_INPUT_TAG_SUFFIX in it["SK"]:
            input_tag_items.append(it)
        else:
            input_items.append(it)

    # Query dataset items for this experiment
    dataset_items = []
    if input_items:
        dataset_items = self._table.query(
            pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
            sk_prefix=SK_DATASET_PREFIX,
        )

    return _item_to_run(
        item,
        tag_items,
        param_items,
        metric_items,
        input_items=input_items,
        dataset_items=dataset_items,
        input_tag_items=input_tag_items,
    )

create_logged_model

create_logged_model(experiment_id, name=None, source_run_id=None, tags=None, params=None, model_type=None)

Create a new logged model within an experiment.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def create_logged_model(
    self,
    experiment_id: str,
    name: str | None = None,
    source_run_id: str | None = None,
    tags: list[LoggedModelTag] | None = None,
    params: list[LoggedModelParameter] | None = None,
    model_type: str | None = None,
) -> LoggedModel:
    """Create a new logged model within an experiment."""
    exp = self.get_experiment(experiment_id)

    now_ms = int(time.time() * 1000)
    model_id = f"m-{generate_ulid()}"
    name = name or model_id
    artifact_location = f"{exp.artifact_location}/models/{model_id}/artifacts/"

    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_LM_PREFIX}{model_id}"

    tag_dict = {t.key: t.value for t in (tags or [])}
    param_dict = {p.key: p.value for p in (params or [])}

    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "model_id": model_id,
        "experiment_id": experiment_id,
        "name": name,
        "artifact_location": artifact_location,
        "creation_timestamp_ms": now_ms,
        "last_updated_timestamp_ms": now_ms,
        "status": str(LoggedModelStatus.PENDING),
        "lifecycle_stage": "active",
        "model_type": model_type or "",
        "source_run_id": source_run_id or "",
        "status_message": "",
        "tags": tag_dict,
        "params": param_dict,
        "workspace": self._workspace,
        # LSI projections for filtering/sorting
        LSI1_SK: f"active#{model_id}",
        LSI2_SK: now_ms,
        LSI3_SK: f"PENDING#{model_id}",
        LSI4_SK: name.lower(),
        # GSI1: reverse lookup model_id -> experiment_id
        GSI1_PK: f"{GSI1_LM_PREFIX}{model_id}",
        GSI1_SK: f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
    }

    self._table.put_item(item, condition="attribute_not_exists(SK)")

    # Write tag sub-items
    for key, value in tag_dict.items():
        self._table.put_item(
            {
                "PK": pk,
                "SK": f"{SK_LM_PREFIX}{model_id}{SK_LM_TAG_PREFIX}{key}",
                "key": key,
                "value": value,
            }
        )

    # Write param sub-items
    for key, value in param_dict.items():
        self._table.put_item(
            {
                "PK": pk,
                "SK": f"{SK_LM_PREFIX}{model_id}{SK_LM_PARAM_PREFIX}{key}",
                "key": key,
                "value": value,
            }
        )

    return _item_to_logged_model(item)
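A sketch of the create/finalize cycle, assuming the logged-model entity classes are importable from mlflow.entities as in recent MLflow releases (names and values are hypothetical):

from mlflow.entities import LoggedModelParameter, LoggedModelStatus, LoggedModelTag

model = store.create_logged_model(
    experiment_id=exp_id,
    name="churn-classifier",  # defaults to the generated "m-<ULID>" id if omitted
    source_run_id=run.info.run_id,
    tags=[LoggedModelTag("framework", "sklearn")],
    params=[LoggedModelParameter("n_estimators", "100")],
)
# Models start PENDING; finalize_logged_model (below) moves them to READY.
model = store.finalize_logged_model(model.model_id, LoggedModelStatus.READY)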

get_logged_model

get_logged_model(model_id, allow_deleted=False)

Fetch a logged model by ID.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_logged_model(self, model_id: str, allow_deleted: bool = False) -> LoggedModel:
    """Fetch a logged model by ID."""
    experiment_id = self._resolve_logged_model_experiment(model_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_LM_PREFIX}{model_id}"

    meta = self._table.get_item(pk=pk, sk=sk)
    if meta is None:
        raise MlflowException(
            f"Logged model with ID '{model_id}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    if meta.get("lifecycle_stage") == "deleted" and not allow_deleted:
        raise MlflowException(
            f"Logged model with ID '{model_id}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Load sub-items
    tag_items = self._table.query(
        pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}{SK_LM_TAG_PREFIX}"
    )
    param_items = self._table.query(
        pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}{SK_LM_PARAM_PREFIX}"
    )
    metric_items = self._table.query(
        pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}{SK_LM_METRIC_PREFIX}"
    )

    return _item_to_logged_model(meta, tag_items, param_items, metric_items)

finalize_logged_model

finalize_logged_model(model_id, status)

Update a logged model's status (e.g. READY or FAILED).

Source code in src/mlflow_dynamodbstore/tracking_store.py
def finalize_logged_model(self, model_id: str, status: LoggedModelStatus) -> LoggedModel:
    """Update a logged model's status (e.g. READY or FAILED)."""
    experiment_id = self._resolve_logged_model_experiment(model_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_LM_PREFIX}{model_id}"
    now_ms = int(time.time() * 1000)

    self._table.update_item(
        pk=pk,
        sk=sk,
        updates={
            "status": str(status),
            "last_updated_timestamp_ms": now_ms,
            LSI3_SK: f"{status}#{model_id}",
        },
    )
    return self.get_logged_model(model_id)

delete_logged_model

delete_logged_model(model_id)

Soft-delete a logged model and set TTL on all related items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_logged_model(self, model_id: str) -> None:
    """Soft-delete a logged model and set TTL on all related items."""
    experiment_id = self._resolve_logged_model_experiment(model_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_LM_PREFIX}{model_id}"
    now_ms = int(time.time() * 1000)

    ttl_seconds = self._config.get_soft_deleted_ttl_seconds()
    ttl_value = int(time.time()) + ttl_seconds if ttl_seconds is not None else None

    updates: dict[str, Any] = {
        "lifecycle_stage": "deleted",
        "last_updated_timestamp_ms": now_ms,
        LSI1_SK: f"deleted#{model_id}",
    }
    if ttl_value is not None:
        updates["ttl"] = ttl_value
    self._table.update_item(pk=pk, sk=sk, updates=updates)

    # Set TTL on child items (tags, params, metrics)
    children = self._table.query(pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}#")
    for child in children:
        if ttl_value is not None:
            self._table.update_item(pk=pk, sk=child["SK"], updates={"ttl": ttl_value})

    # Set TTL on RANK items built from metric sub-items
    metric_children = [c for c in children if SK_LM_METRIC_PREFIX in c["SK"]]
    for mc in metric_children:
        inv_value = self._invert_metric_value(float(mc["metric_value"]))
        rank_sk = f"{SK_RANK_LM_PREFIX}{mc['metric_name']}#{inv_value}#{model_id}"
        if ttl_value is not None:
            self._table.update_item(pk=pk, sk=rank_sk, updates={"ttl": ttl_value})
        if mc.get("dataset_name") and mc.get("dataset_digest"):
            rank_sk_ds = (
                f"{SK_RANK_LMD_PREFIX}{mc['metric_name']}#"
                f"{mc['dataset_name']}#{mc['dataset_digest']}#{inv_value}#{model_id}"
            )
            if ttl_value is not None:
                self._table.update_item(pk=pk, sk=rank_sk_ds, updates={"ttl": ttl_value})

set_logged_model_tags

set_logged_model_tags(model_id, tags)

Set (or overwrite) tags on a logged model.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def set_logged_model_tags(self, model_id: str, tags: list[LoggedModelTag]) -> None:
    """Set (or overwrite) tags on a logged model."""
    experiment_id = self._resolve_logged_model_experiment(model_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    now_ms = int(time.time() * 1000)

    tag_dict: dict[str, str] = {}
    for tag in tags:
        self._table.put_item(
            {
                "PK": pk,
                "SK": f"{SK_LM_PREFIX}{model_id}{SK_LM_TAG_PREFIX}{tag.key}",
                "key": tag.key,
                "value": tag.value,
            }
        )
        tag_dict[tag.key] = tag.value

    # Update denormalized tags on META item
    meta = self._table.get_item(pk=pk, sk=f"{SK_LM_PREFIX}{model_id}") or {}
    existing_tags: dict[str, str] = meta.get("tags", {})
    existing_tags.update(tag_dict)
    self._table.update_item(
        pk=pk,
        sk=f"{SK_LM_PREFIX}{model_id}",
        updates={"tags": existing_tags, "last_updated_timestamp_ms": now_ms},
    )

delete_logged_model_tag

delete_logged_model_tag(model_id, key)

Delete a tag from a logged model.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_logged_model_tag(self, model_id: str, key: str) -> None:
    """Delete a tag from a logged model."""
    experiment_id = self._resolve_logged_model_experiment(model_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    now_ms = int(time.time() * 1000)

    tag_sk = f"{SK_LM_PREFIX}{model_id}{SK_LM_TAG_PREFIX}{key}"
    existing = self._table.get_item(pk=pk, sk=tag_sk)
    if existing is None:
        raise MlflowException(
            f"No tag with key '{key}' for logged model '{model_id}'.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    self._table.delete_item(pk=pk, sk=tag_sk)

    # Update denormalized tags on META item
    meta = self._table.get_item(pk=pk, sk=f"{SK_LM_PREFIX}{model_id}") or {}
    existing_tags: dict[str, str] = meta.get("tags", {})
    existing_tags.pop(key, None)
    self._table.update_item(
        pk=pk,
        sk=f"{SK_LM_PREFIX}{model_id}",
        updates={"tags": existing_tags, "last_updated_timestamp_ms": now_ms},
    )

record_logged_model

record_logged_model(run_id, mlflow_model)

Append a model info dict to the run's mlflow.loggedModels tag.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def record_logged_model(self, run_id: str, mlflow_model: dict[str, Any]) -> None:
    """Append a model info dict to the run's mlflow.loggedModels tag."""
    import json as _json

    run = self.get_run(run_id)
    experiment_id = run.info.experiment_id
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    tag_sk = f"{SK_RUN_PREFIX}{run_id}{SK_TAG_PREFIX}mlflow.loggedModels"
    existing = self._table.get_item(pk=pk, sk=tag_sk)
    models = _json.loads(existing["value"]) if existing else []
    models.append(mlflow_model)
    serialized = _json.dumps(models)

    self._table.put_item(
        {
            "PK": pk,
            "SK": tag_sk,
            "key": "mlflow.loggedModels",
            "value": serialized,
        }
    )

    # Update denormalized tags on run META
    run_sk = f"{SK_RUN_PREFIX}{run_id}"
    meta = self._table.get_item(pk=pk, sk=run_sk)
    run_tags = meta.get("tags", {}) if meta else {}
    run_tags["mlflow.loggedModels"] = serialized
    self._table.update_item(pk=pk, sk=run_sk, updates={"tags": run_tags})

search_logged_models

search_logged_models(experiment_ids, filter_string=None, datasets=None, max_results=None, order_by=None, page_token=None)

Search logged models across experiments using parse/plan/execute pipeline.

Collects all matching items from every experiment first, then paginates the merged result using an offset-based token.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def search_logged_models(
    self,
    experiment_ids: list[str],
    filter_string: str | None = None,
    datasets: list[dict[str, Any]] | None = None,
    max_results: int | None = None,
    order_by: list[dict[str, Any]] | None = None,
    page_token: str | None = None,
) -> PagedList[LoggedModel]:
    """Search logged models across experiments using parse/plan/execute pipeline.

    Collects all matching items from every experiment first, then paginates
    the merged result using an offset-based token.
    """
    from mlflow_dynamodbstore.dynamodb.pagination import decode_page_token, encode_page_token
    from mlflow_dynamodbstore.dynamodb.search import (
        execute_logged_model_query,
        parse_logged_model_filter,
        plan_logged_model_query,
    )

    max_results = max_results or 100

    # Validate filter string using MLflow's parser (raises with standard error message)
    if filter_string:
        from mlflow.utils.search_logged_model_utils import parse_filter_string

        parse_filter_string(filter_string)

    predicates = parse_logged_model_filter(filter_string)
    plan = plan_logged_model_query(predicates, order_by, datasets)

    # Collect all matching items across experiments (pagination applied later)
    all_items: list[dict[str, Any]] = []
    for exp_id in experiment_ids:
        pk = f"{PK_EXPERIMENT_PREFIX}{exp_id}"
        items = execute_logged_model_query(self._table, plan, pk, predicates)
        all_items.extend(items)

    # Convert items to LoggedModel entities
    models: list[LoggedModel] = []
    for item in all_items:
        exp_id = item["experiment_id"]
        model_id = item["model_id"]
        pk = f"{PK_EXPERIMENT_PREFIX}{exp_id}"
        tag_items = self._table.query(
            pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}{SK_LM_TAG_PREFIX}"
        )
        param_items = self._table.query(
            pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}{SK_LM_PARAM_PREFIX}"
        )
        metric_items = self._table.query(
            pk=pk, sk_prefix=f"{SK_LM_PREFIX}{model_id}{SK_LM_METRIC_PREFIX}"
        )
        models.append(_item_to_logged_model(item, tag_items, param_items, metric_items))

    # Apply offset-based pagination across the merged result
    token_data = decode_page_token(page_token)
    offset = token_data.get("offset", 0) if token_data else 0
    page = models[offset : offset + max_results]
    has_more = len(models) > offset + max_results
    next_token = encode_page_token({"offset": offset + max_results}) if has_more else None

    return PagedList(page, next_token)
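Because the token encodes a plain offset over the merged result, paging is a straightforward loop. A sketch with a hypothetical metric filter:

token = None
while True:
    page = store.search_logged_models(
        experiment_ids=[exp_id],
        filter_string="metrics.accuracy > 0.9",
        max_results=50,
        page_token=token,
    )
    for m in page:
        print(m.model_id, m.name)
    token = page.token
    if token is None:
        break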

update_run_info

update_run_info(run_id, run_status, end_time, run_name)

Update run status, end_time, and run_name.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def update_run_info(
    self, run_id: str, run_status: str | int, end_time: int, run_name: str
) -> RunInfo:
    """Update run status, end_time, and run_name."""
    from mlflow.entities import RunStatus

    # MLflow may pass status as protobuf enum int (e.g. 3 for FINISHED)
    if isinstance(run_status, int):
        run_status = RunStatus.to_string(run_status)

    experiment_id = self._resolve_run_experiment(run_id)

    # Get current run to compute duration
    current = self._table.get_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=f"{SK_RUN_PREFIX}{run_id}",
    )
    if current is None:
        raise MlflowException(
            f"Run '{run_id}' does not exist.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    updates: dict[str, Any] = {
        "status": run_status,
        LSI3_SK: f"{run_status}#{run_id}",
    }
    removes: list[str] = []
    if run_name:
        updates["run_name"] = run_name
        updates[LSI4_SK] = run_name.lower()

    if end_time is not None:
        updates["end_time"] = end_time
        updates[LSI2_SK] = end_time
        start_time = current.get("start_time", 0)
        if start_time:
            updates[LSI5_SK] = str(end_time - start_time)

    updated_item = self._table.update_item(
        pk=f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
        sk=f"{SK_RUN_PREFIX}{run_id}",
        updates=updates,
        removes=removes if removes else None,
    )

    assert updated_item is not None  # update always returns ALL_NEW

    # Update mlflow.runName tag when name changes
    old_run_name = current.get("run_name", "")
    if run_name and run_name != old_run_name:
        from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME

        self._write_run_tag(experiment_id, run_id, RunTag(MLFLOW_RUN_NAME, run_name))

    # Update FTS for run name if it changed
    if run_name and run_name != old_run_name:
        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
        self._update_fts_for_rename(
            pk=pk,
            entity_type="R",
            entity_id=run_id,
            field=None,
            old_text=old_run_name or None,
            new_text=run_name,
            workspace=None,
        )

    return _item_to_run_info(updated_item)
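A usage sketch (the new run name is hypothetical):

import time

info = store.update_run_info(
    run_id=run.info.run_id,
    run_status="FINISHED",  # protobuf enum ints (e.g. 3) are also accepted
    end_time=int(time.time() * 1000),
    run_name="baseline-v2",  # triggers the tag update and FTS refresh
)
print(info.status, info.end_time)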

delete_run

delete_run(run_id)

Soft-delete a run and set TTL on all related items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_run(self, run_id: str) -> None:
    """Soft-delete a run and set TTL on all related items."""
    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    # Compute TTL if enabled
    ttl_seconds = self._config.get_soft_deleted_ttl_seconds()
    ttl_value = int(time.time()) + ttl_seconds if ttl_seconds is not None else None

    # Update run META item
    updates: dict[str, Any] = {
        "lifecycle_stage": "deleted",
        LSI1_SK: f"deleted#{run_id}",
    }
    if ttl_value is not None:
        updates["ttl"] = ttl_value

    self._table.update_item(pk=pk, sk=f"{SK_RUN_PREFIX}{run_id}", updates=updates)

    # Set TTL on related items if enabled
    if ttl_value is not None:
        self._set_ttl_on_run_related_items(pk, run_id, ttl_value)

restore_run

restore_run(run_id)

Restore a soft-deleted run and remove TTL from all related items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def restore_run(self, run_id: str) -> None:
    """Restore a soft-deleted run and remove TTL from all related items."""
    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    # Update run META: restore lifecycle and remove TTL
    self._table.update_item(
        pk=pk,
        sk=f"{SK_RUN_PREFIX}{run_id}",
        updates={
            "lifecycle_stage": "active",
            LSI1_SK: f"active#{run_id}",
        },
        removes=["ttl"],
    )

    # Remove TTL from all related items
    self._remove_ttl_from_run_related_items(pk, run_id)

log_metric

log_metric(run_id, metric)

Log a single metric, validating before batch.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def log_metric(self, run_id: str, metric: Metric) -> None:
    """Log a single metric, validating before batch."""
    from mlflow.utils.validation import _validate_metric

    _validate_metric(metric.key, metric.value, metric.timestamp, metric.step)
    self.log_batch(run_id, metrics=[metric], params=[], tags=[])

log_batch

log_batch(run_id, metrics, params, tags)

Log a batch of metrics, params, and tags for a run.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def log_batch(
    self,
    run_id: str,
    metrics: list[Metric],
    params: list[Param],
    tags: list[RunTag],
) -> None:
    """Log a batch of metrics, params, and tags for a run."""
    from mlflow.utils.validation import _validate_batch_log_data

    metrics, params, tags = _validate_batch_log_data(metrics, params, tags)

    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    items: list[dict[str, Any]] = []

    for metric in metrics:
        # Latest metric item (overwrites previous for same key)
        latest_sk = f"{SK_RUN_PREFIX}{run_id}{SK_METRIC_PREFIX}{metric.key}"
        ddb_value = Decimal(str(metric.value))
        items.append(
            {
                "PK": pk,
                "SK": latest_sk,
                "key": metric.key,
                "value": ddb_value,
                "timestamp": metric.timestamp,
                "step": metric.step,
            }
        )

        # History item (unique per key+step+timestamp)
        padded = pad_step(metric.step)
        hist_sk = (
            f"{SK_RUN_PREFIX}{run_id}{SK_METRIC_HISTORY_PREFIX}"
            f"{metric.key}#{padded}#{metric.timestamp}"
        )
        hist_item: dict[str, Any] = {
            "PK": pk,
            "SK": hist_sk,
            "key": metric.key,
            "value": ddb_value,
            "timestamp": metric.timestamp,
            "step": metric.step,
        }
        metric_history_ttl = self._config.get_metric_history_ttl_seconds()
        if metric_history_ttl is not None:
            hist_item["ttl"] = int(time.time()) + metric_history_ttl
        items.append(hist_item)

        # RANK item for metric (inverted value for descending sort)
        max_val = 9999999999.9999
        inv = max_val - float(metric.value)
        inv_str = f"{inv:020.4f}"
        rank_sk = f"{SK_RANK_PREFIX}m#{metric.key}#{inv_str}#{run_id}"
        items.append(
            {
                "PK": pk,
                "SK": rank_sk,
                "key": metric.key,
                "value": ddb_value,
                "run_id": run_id,
            }
        )

    for param in params:
        param_sk = f"{SK_RUN_PREFIX}{run_id}{SK_PARAM_PREFIX}{param.key}"
        items.append(
            {
                "PK": pk,
                "SK": param_sk,
                "key": param.key,
                "value": param.value,
            }
        )

        # RANK item for param
        rank_sk = f"{SK_RANK_PREFIX}p#{param.key}#{param.value}#{run_id}"
        items.append(
            {
                "PK": pk,
                "SK": rank_sk,
                "key": param.key,
                "value": param.value,
                "run_id": run_id,
            }
        )

    # Write all items in batch
    if items:
        self._table.batch_write(items)

    # Write FTS items for param values if configured
    if self._config.should_trigram("run_param_value"):
        fts_param_items: list[dict[str, Any]] = []
        for param in params:
            if param.value:
                fts_param_items.extend(
                    fts_items_for_text(
                        pk=pk,
                        entity_type="R",
                        entity_id=run_id,
                        field=param.key,
                        text=param.value,
                    )
                )
        if fts_param_items:
            self._table.batch_write(fts_param_items)

    # Write tags individually (uses put_item)
    for tag in tags:
        self._write_run_tag(experiment_id, run_id, tag)

log_outputs

log_outputs(run_id, models)

Associate logged model outputs with a run.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def log_outputs(self, run_id: str, models: list[LoggedModelOutput]) -> None:
    """Associate logged model outputs with a run."""
    if not models:
        return

    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    # Verify run is active (not deleted)
    meta = self._table.get_item(pk=pk, sk=f"{SK_RUN_PREFIX}{run_id}")
    if meta is None:
        raise MlflowException(
            f"Run '{run_id}' does not exist.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    if meta.get("lifecycle_stage") == "deleted":
        raise MlflowException(
            f"Run '{run_id}' is deleted.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    items: list[dict[str, Any]] = []
    for model in models:
        output_id = generate_ulid()
        items.append(
            {
                "PK": pk,
                "SK": f"{SK_RUN_PREFIX}{run_id}{SK_OUTPUT_PREFIX}{output_id}",
                "source_type": "RUN_OUTPUT",
                "source_id": run_id,
                "destination_type": "MODEL_OUTPUT",
                "destination_id": model.model_id,
                "step": model.step,
            }
        )

    self._table.batch_write(items)

get_metric_history

get_metric_history(run_id, metric_key, max_results=None, page_token=None)

Return the history of a metric for a run, ordered by step.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_metric_history(
    self,
    run_id: str,
    metric_key: str,
    max_results: int | None = None,
    page_token: str | None = None,
) -> list[Metric]:
    """Return the history of a metric for a run, ordered by step."""
    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    prefix = f"{SK_RUN_PREFIX}{run_id}{SK_METRIC_HISTORY_PREFIX}{metric_key}#"

    items = self._table.query(
        pk=pk,
        sk_prefix=prefix,
        limit=max_results,
    )

    from mlflow.store.entities import PagedList

    metrics = [
        Metric(
            key=item["key"],
            value=float(item["value"]),
            timestamp=int(item.get("timestamp", 0)),
            step=int(item.get("step", 0)),
        )
        for item in items
    ]
    return PagedList(metrics, token=None)
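History items embed the zero-padded step in the sort key, so the prefix query already returns step order; the PagedList comes back with token=None (no server-side pagination). A sketch:

history = store.get_metric_history(run.info.run_id, "accuracy")
for m in history:
    print(m.step, m.timestamp, m.value)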

get_metric_history_bulk_interval_from_steps

get_metric_history_bulk_interval_from_steps(run_id, metric_key, steps, max_results)

Return metric history for specific steps, optimized for DynamoDB.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_metric_history_bulk_interval_from_steps(
    self, run_id: str, metric_key: str, steps: list[int], max_results: int
) -> list[Any]:
    """Return metric history for specific steps, optimized for DynamoDB."""
    from mlflow.entities.metric import MetricWithRunId

    if not steps:
        return []

    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    prefix = f"{SK_RUN_PREFIX}{run_id}{SK_METRIC_HISTORY_PREFIX}{metric_key}#"

    steps_set = set(steps)
    items = self._table.query(pk=pk, sk_prefix=prefix)

    metrics = sorted(
        [
            Metric(
                key=item["key"],
                value=float(item["value"]),
                timestamp=int(item.get("timestamp", 0)),
                step=int(item.get("step", 0)),
            )
            for item in items
            if int(item.get("step", 0)) in steps_set
        ],
        key=lambda m: (m.step, m.timestamp),
    )[:max_results]

    return [MetricWithRunId(run_id=run_id, metric=m) for m in metrics]

set_tag

set_tag(run_id, tag)

Set a tag on a run.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def set_tag(self, run_id: str, tag: RunTag) -> None:
    """Set a tag on a run."""
    from mlflow.utils.validation import _validate_tag_name

    _validate_tag_name(tag.key)
    experiment_id = self._resolve_run_experiment(run_id)
    self._write_run_tag(experiment_id, run_id, tag)

delete_tag

delete_tag(run_id, key)

Delete a tag from a run.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_tag(self, run_id: str, key: str) -> None:
    """Delete a tag from a run."""
    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_RUN_PREFIX}{run_id}{SK_TAG_PREFIX}{key}"
    self._table.delete_item(pk=pk, sk=sk)
    if self._config.should_denormalize(experiment_id, key):
        self._remove_denormalized_tag(pk, f"{SK_RUN_PREFIX}{run_id}", key)
    # Remove FTS items for tag value if FTS was configured
    if self._config.should_trigram("run_tag_value"):
        self._delete_fts_for_entity_field(pk=pk, entity_type="R", entity_id=run_id, field=key)

log_inputs

log_inputs(run_id, datasets=None, models=None)

Log dataset inputs for a run.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def log_inputs(
    self,
    run_id: str,
    datasets: list[DatasetInput] | None = None,
    models: Any = None,
) -> None:
    """Log dataset inputs for a run."""
    if not datasets:
        return

    experiment_id = self._resolve_run_experiment(run_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    items: list[dict[str, Any]] = []

    for dataset_input in datasets:
        ds = dataset_input.dataset
        ds_uuid = generate_ulid()

        # Dataset item: PK=EXP#<exp_id>, SK=D#<name>#<digest>
        items.append(
            {
                "PK": pk,
                "SK": f"{SK_DATASET_PREFIX}{ds.name}#{ds.digest}",
                "name": ds.name,
                "digest": ds.digest,
                "source_type": ds.source_type,
                "source": ds.source,
                "schema": ds.schema,
                "profile": ds.profile,
            }
        )

        # Input link item: PK=EXP#<exp_id>, SK=R#<run_id>#INPUT#<ds_uuid>
        items.append(
            {
                "PK": pk,
                "SK": f"{SK_RUN_PREFIX}{run_id}{SK_INPUT_PREFIX}{ds_uuid}",
                "dataset_name": ds.name,
                "dataset_digest": ds.digest,
            }
        )

        # Input tag items and extract context tag
        context: str | None = None
        for tag in dataset_input.tags:
            items.append(
                {
                    "PK": pk,
                    "SK": (
                        f"{SK_RUN_PREFIX}{run_id}{SK_INPUT_PREFIX}{ds_uuid}"
                        f"{SK_INPUT_TAG_SUFFIX}{tag.key}"
                    ),
                    "key": tag.key,
                    "value": tag.value,
                }
            )
            if tag.key == "mlflow.data.context":
                context = tag.value

        # DLINK materialized item: PK=EXP#<exp_id>, SK=DLINK#<name>#<digest>#R#<run_id>
        dlink_item: dict[str, Any] = {
            "PK": pk,
            "SK": f"{SK_DLINK_PREFIX}{ds.name}#{ds.digest}#{SK_RUN_PREFIX}{run_id}",
            "dataset_name": ds.name,
            "dataset_digest": ds.digest,
            "run_id": run_id,
        }
        if context is not None:
            dlink_item["context"] = context
        items.append(dlink_item)

    self._table.batch_write(items)
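A usage sketch with hypothetical dataset values; note that the mlflow.data.context tag is also denormalized onto the DLINK item for dataset-centric queries:

from mlflow.entities import Dataset, DatasetInput, InputTag

ds = Dataset(
    name="train-set",
    digest="abc123",
    source_type="s3",
    source='{"uri": "s3://my-bucket/train.parquet"}',
)
store.log_inputs(
    run_id=run.info.run_id,
    datasets=[DatasetInput(ds, tags=[InputTag("mlflow.data.context", "training")])],
)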

start_trace

start_trace(trace_info)

Create a trace in DynamoDB from a TraceInfo object.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def start_trace(self, trace_info: TraceInfo) -> TraceInfo:
    """Create a trace in DynamoDB from a TraceInfo object."""
    mlflow_exp = trace_info.trace_location.mlflow_experiment
    if mlflow_exp is None:
        raise MlflowException(
            "TraceInfo must have an MLflow experiment location.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    experiment_id = mlflow_exp.experiment_id
    trace_id = trace_info.trace_id
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}"

    ttl = self._get_trace_ttl()

    # Extract trace name from trace_metadata or tags
    trace_name = trace_info.trace_metadata.get(
        TraceTagKey.TRACE_NAME, ""
    ) or trace_info.tags.get(TraceTagKey.TRACE_NAME, "")
    execution_duration = trace_info.execution_duration or 0
    request_time = trace_info.request_time
    state_str = str(trace_info.state)

    # Build META item
    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "trace_id": trace_id,
        "experiment_id": experiment_id,
        "request_time": request_time,
        "execution_duration": execution_duration,
        "state": state_str,
        "tags": {},
        # LSI attributes (must be strings, zero-padded for sort order)
        LSI1_SK: f"{request_time:020d}",
        LSI2_SK: request_time + execution_duration,
        LSI3_SK: f"{state_str}#{request_time:020d}",
        LSI5_SK: f"{execution_duration:020d}",
        # GSI1: reverse lookup trace_id -> experiment_id
        GSI1_PK: f"{GSI1_TRACE_PREFIX}{trace_id}",
        GSI1_SK: f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
    }

    if trace_name:
        item[LSI4_SK] = trace_name.lower()

    if ttl is not None:
        item["ttl"] = ttl

    if trace_info.client_request_id:
        item["client_request_id"] = trace_info.client_request_id

    try:
        self._table.put_item(item, condition="attribute_not_exists(PK)")
    except Exception:
        # Trace META may already exist (created by log_spans) — update instead
        update_fields = {k: v for k, v in item.items() if k not in ("PK", "SK")}
        self._table.update_item(pk=pk, sk=sk, updates=update_fields)

    # Cache trace_id -> experiment_id
    self._cache.put("trace_exp", trace_id, experiment_id)

    # Write CLIENTPTR item if client_request_id is present
    if trace_info.client_request_id:
        ptr_item: dict[str, Any] = {
            "PK": pk,
            "SK": f"{SK_TRACE_PREFIX}{trace_id}#CLIENTPTR",
            GSI1_PK: f"{GSI1_CLIENT_PREFIX}{trace_info.client_request_id}",
            GSI1_SK: f"{GSI1_TRACE_PREFIX}{trace_id}",
        }
        if ttl is not None:
            ptr_item["ttl"] = ttl
        self._table.put_item(ptr_item)

    # Write request metadata items
    if trace_info.trace_metadata:
        rmeta_items: list[dict[str, Any]] = []
        for key, value in trace_info.trace_metadata.items():
            rmeta_item: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_TRACE_PREFIX}{trace_id}#RMETA#{key}",
                "key": key,
                "value": value,
            }
            if ttl is not None:
                rmeta_item["ttl"] = ttl
            rmeta_items.append(rmeta_item)
        if rmeta_items:
            self._table.batch_write(rmeta_items)

    # Write initial tag items + denormalization + FTS
    if trace_info.tags:
        for tag_key, tag_value in trace_info.tags.items():
            self._write_trace_tag(experiment_id, trace_id, tag_key, tag_value, ttl)

    # Ensure artifact location tag is set (required by MLflow trace export).
    # Matches SQLAlchemy store: always compute from experiment artifact_location.
    if MLFLOW_ARTIFACT_LOCATION not in (trace_info.tags or {}):
        exp = self.get_experiment(experiment_id)
        artifact_loc = append_to_uri_path(
            exp.artifact_location, "traces", trace_id, "artifacts"
        )
        self._write_trace_tag(
            experiment_id, trace_id, MLFLOW_ARTIFACT_LOCATION, artifact_loc, ttl
        )
        if trace_info.tags is None:
            trace_info.tags = {}
        trace_info.tags[MLFLOW_ARTIFACT_LOCATION] = artifact_loc

    # Upsert session tracker if trace has session metadata
    session_id = (trace_info.trace_metadata or {}).get("mlflow.traceSession")
    if session_id:
        self._upsert_session_tracker(
            experiment_id=experiment_id,
            session_id=session_id,
            timestamp_ms=trace_info.request_time,
            ttl=ttl,
        )

    return trace_info

get_trace_info

get_trace_info(trace_id)

Fetch a trace by ID, reconstructing TraceInfo from DynamoDB items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_trace_info(self, trace_id: str) -> TraceInfo:
    """Fetch a trace by ID, reconstructing TraceInfo from DynamoDB items."""
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}"

    meta = self._table.get_item(pk=pk, sk=sk)
    if meta is None:
        raise MlflowException(
            f"Trace with ID {trace_id} is not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Read tags
    tag_items = self._table.query(
        pk=pk,
        sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}#TAG#",
    )
    tags = {item["key"]: item["value"] for item in tag_items}

    # Read request metadata
    rmeta_items = self._table.query(
        pk=pk,
        sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}#RMETA#",
    )
    trace_metadata = {item["key"]: item["value"] for item in rmeta_items}

    state_str = meta.get("state", "STATE_UNSPECIFIED")
    state = TraceState(state_str)

    assessments = self._load_assessments(pk, trace_id)

    return TraceInfo(
        trace_id=trace_id,
        trace_location=TraceLocation(
            type=TraceLocationType.MLFLOW_EXPERIMENT,
            mlflow_experiment=MlflowExperimentLocation(experiment_id=experiment_id),
        ),
        request_time=int(meta["request_time"]),
        execution_duration=int(meta.get("execution_duration", 0)),
        state=state,
        trace_metadata=trace_metadata,
        tags=tags,
        client_request_id=meta.get("client_request_id"),
        assessments=assessments,
    )

get_trace

get_trace(trace_id, *, allow_partial=False)

Fetch a trace with spans.

Flow:

1. Read trace info (META + tags + metadata + assessments) from DynamoDB
2. Check for cached spans: T#<trace_id>#SPANS item
3. If cached -> deserialize and use them
4. If not cached -> call XRayClient.batch_get_traces([trace_id])
5. Convert X-Ray segments -> span dicts via span_converter
6. Cache to DynamoDB: T#<trace_id>#SPANS (JSON blob, same TTL as trace)
7. Denormalize span attributes on META: span_types, span_statuses, span_names
8. Write FTS items for span names
9. Return complete Trace

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_trace(
    self,
    trace_id: str,
    *,
    allow_partial: bool = False,
) -> Trace:
    """Fetch a trace with spans.

    Flow:
    1. Read trace info (META + tags + metadata + assessments) from DynamoDB
    2. Check for cached spans: T#<trace_id>#SPANS item
    3. If cached -> deserialize and use them
    4. If not cached -> call XRayClient.batch_get_traces([trace_id])
    5. Convert X-Ray segments -> span dicts via span_converter
    6. Cache to DynamoDB: T#<trace_id>#SPANS (JSON blob, same TTL as trace)
    7. Denormalize span attributes on META: span_types, span_statuses, span_names
    8. Write FTS items for span names
    9. Return complete Trace
    """
    import json as _json

    from mlflow.entities.trace import Trace
    from mlflow.entities.trace_data import TraceData

    from mlflow_dynamodbstore.xray.span_converter import (
        convert_xray_trace,
        span_dicts_to_mlflow_spans,
    )

    trace_info = self.get_trace_info(trace_id)
    assert trace_info.trace_location.mlflow_experiment is not None
    experiment_id = trace_info.trace_location.mlflow_experiment.experiment_id
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}"

    # Check for cached spans
    spans_sk = f"{SK_TRACE_PREFIX}{trace_id}#SPANS"
    cached = self._table.get_item(pk=pk, sk=spans_sk)

    span_dicts: list[dict[str, Any]]
    if cached is not None:
        span_dicts = _json.loads(cached["data"])
    else:
        # Fetch from X-Ray
        xray_traces = self._xray_client.batch_get_traces([trace_id])
        if xray_traces:
            span_dicts = convert_xray_trace(xray_traces[0])
        else:
            span_dicts = []

        # Get TTL from META
        meta = self._table.get_item(pk=pk, sk=sk)
        ttl = int(meta["ttl"]) if meta and "ttl" in meta else self._get_trace_ttl()

        # Cache the spans
        spans_item: dict[str, Any] = {
            "PK": pk,
            "SK": spans_sk,
            "data": _json.dumps(span_dicts),
        }
        if ttl is not None:
            spans_item["ttl"] = ttl
        self._table.put_item(spans_item)

        # Denormalize span attributes on META (skip if already done)
        if span_dicts and not (meta or {}).get("span_types"):
            span_types = set()
            span_statuses = set()
            span_names = set()
            for sd in span_dicts:
                if sd.get("span_type"):
                    span_types.add(sd["span_type"])
                if sd.get("status"):
                    span_statuses.add(sd["status"])
                if sd.get("name"):
                    span_names.add(sd["name"])

            updates: dict[str, Any] = {}
            if span_types:
                updates["span_types"] = span_types
            if span_statuses:
                updates["span_statuses"] = span_statuses
            if span_names:
                updates["span_names"] = span_names

            if updates:
                self._table.update_item(pk=pk, sk=sk, updates=updates)

            # Write FTS items for span names
            if span_names:
                span_names_text = " ".join(sorted(span_names))
                fts_items = fts_items_for_text(
                    pk=pk,
                    entity_type="T",
                    entity_id=trace_id,
                    field="spans",
                    text=span_names_text,
                )
                if ttl is not None:
                    for item in fts_items:
                        item["ttl"] = ttl
                self._table.batch_write(fts_items)

    # Convert span dicts to MLflow Span objects
    # Try V3 format (Span.to_dict) first — preserves original span IDs.
    # Fall back to X-Ray converter which hashes span IDs.
    if span_dicts and "start_time_unix_nano" in span_dicts[0]:
        try:
            from mlflow.entities.span import Span as SpanEntity

            spans = [SpanEntity.from_dict(sd) for sd in span_dicts]
        except Exception:
            spans = span_dicts_to_mlflow_spans(span_dicts, trace_id)
    else:
        spans = span_dicts_to_mlflow_spans(span_dicts, trace_id)

    return Trace(info=trace_info, data=TraceData(spans=spans))
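
A minimal usage sketch, assuming store is an already-initialized DynamoDBTrackingStore and the trace ID is a placeholder:

trace = store.get_trace("tr-0123456789abcdef")  # first call may fetch spans from X-Ray
print(trace.info.trace_id, trace.info.state, len(trace.data.spans))
trace = store.get_trace("tr-0123456789abcdef")  # repeat calls read the cached SPANS item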

log_spans

log_spans(location, spans, tracking_uri=None)

Log spans to the tracking store by writing SPANS cache items.

In addition to the SPANS JSON blob, writes:

- Individual span items (T#<trace_id>#SPAN#<span_id>)
- Trace metric items (T#<trace_id>#TMETRIC#<key>) for aggregated token usage
- Span metric items (T#<trace_id>#SMETRIC#<span_id>#<key>) for per-span costs
- META denormalization of span_types, span_names, span_statuses

Source code in src/mlflow_dynamodbstore/tracking_store.py
def log_spans(
    self, location: str, spans: list[Any], tracking_uri: str | None = None
) -> list[Any]:
    """Log spans to the tracking store by writing SPANS cache items.

    In addition to the SPANS JSON blob, writes:
    - Individual span items (T#<trace_id>#SPAN#<span_id>)
    - Trace metric items (T#<trace_id>#TMETRIC#<key>) for aggregated token usage
    - Span metric items (T#<trace_id>#SMETRIC#<span_id>#<key>) for per-span costs
    - META denormalization of span_types, span_names, span_statuses
    """
    import json as _json
    from collections import defaultdict

    if not spans:
        return []

    # Group spans by trace_id
    spans_by_trace: dict[str, list[Any]] = defaultdict(list)
    for span in spans:
        spans_by_trace[span.trace_id].append(span)

    for trace_id, trace_spans in spans_by_trace.items():
        try:
            experiment_id = location or self._resolve_trace_experiment(trace_id)
        except MlflowException:
            if not location:
                continue  # Skip unresolvable traces
            experiment_id = location

        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
        sk = f"{SK_TRACE_PREFIX}{trace_id}"

        # Read TTL from trace META, create if not exists
        meta = self._table.get_item(pk=pk, sk=sk)
        ttl = int(meta["ttl"]) if meta and "ttl" in meta else self._get_trace_ttl()

        if meta is None:
            # Trace doesn't exist yet — create META from span data
            from mlflow.entities import TraceState

            min_start_ns = min(
                getattr(s, "start_time_ns", None) or s.to_dict().get("start_time_ns", 0)
                for s in trace_spans
            )
            request_time = min_start_ns // 1_000_000

            end_times = [
                getattr(s, "end_time_ns", None) or s.to_dict().get("end_time_ns")
                for s in trace_spans
            ]
            end_times = [t for t in end_times if t is not None and t > 0]
            max_end_ms = (max(end_times) // 1_000_000) if end_times else None
            execution_duration = (max_end_ms - request_time) if max_end_ms else 0

            root_status = self._get_trace_status_from_spans(trace_spans)
            state_str = root_status or TraceState.IN_PROGRESS.value

            meta_item: dict[str, Any] = {
                "PK": pk,
                "SK": sk,
                "trace_id": trace_id,
                "experiment_id": experiment_id,
                "request_time": request_time,
                "execution_duration": execution_duration,
                "state": state_str,
                "tags": {},
                LSI1_SK: f"{request_time:020d}",
                LSI2_SK: request_time + execution_duration,
                LSI3_SK: f"{state_str}#{request_time:020d}",
                LSI5_SK: f"{execution_duration:020d}",
                GSI1_PK: f"{GSI1_TRACE_PREFIX}{trace_id}",
                GSI1_SK: f"{PK_EXPERIMENT_PREFIX}{experiment_id}",
            }
            if ttl is not None:
                meta_item["ttl"] = ttl
            self._table.put_item(meta_item)
            meta = meta_item

            # Cache trace_id -> experiment_id
            self._cache.put("trace_exp", trace_id, experiment_id)

        span_dicts = [s.to_dict() for s in trace_spans]

        # Merge with existing cached spans (log_spans may be called multiple
        # times for the same trace — e.g., parent span first, child span second)
        spans_sk = f"{SK_TRACE_PREFIX}{trace_id}#SPANS"
        existing = self._table.get_item(pk=pk, sk=spans_sk)
        if existing is not None:
            existing_dicts = _json.loads(existing["data"])
            # Deduplicate by span_id
            existing_by_id = {sd.get("span_id"): sd for sd in existing_dicts}
            for sd in span_dicts:
                existing_by_id[sd.get("span_id")] = sd
            span_dicts = list(existing_by_id.values())

        spans_item: dict[str, Any] = {
            "PK": pk,
            "SK": spans_sk,
            "data": _json.dumps(span_dicts),
        }
        if ttl is not None:
            spans_item["ttl"] = ttl
        self._table.put_item(spans_item)

        # --- Write individual span items, metrics, and denormalize META ---
        extra_items: list[dict[str, Any]] = []
        span_types: set[str] = set()
        span_statuses: set[str] = set()
        span_names: set[str] = set()
        # Accumulators for trace-level token usage and cost
        total_input_tokens = 0
        total_output_tokens = 0
        total_total_tokens = 0
        has_token_usage = False
        total_input_cost = 0.0
        total_output_cost = 0.0
        total_total_cost = 0.0
        has_cost = False

        for span in trace_spans:
            sd = span.to_dict()
            attrs = sd.get("attributes", {})

            # Read span fields — prefer direct properties, fall back to dict
            span_id = getattr(span, "span_id", None) or sd.get("span_id")
            name = getattr(span, "name", None) or sd.get("name", "")
            span_type = getattr(span, "span_type", None) or sd.get("span_type", "")
            start_ns = getattr(span, "start_time_ns", None)
            if start_ns is None or not isinstance(start_ns, int | float):
                start_ns = sd.get("start_time_ns", sd.get("start_time_unix_nano", 0))
            end_ns = getattr(span, "end_time_ns", None)
            if end_ns is None or not isinstance(end_ns, int | float):
                end_ns = sd.get("end_time_ns", sd.get("end_time_unix_nano", 0))

            # Extract status string
            status = getattr(span, "status", None) or sd.get("status", "")
            if hasattr(status, "status_code"):
                status_str = str(status.status_code)
            elif isinstance(status, dict):
                status_str = str(status.get("code", status))
            else:
                status_str = str(status)

            # Collect for META denormalization
            if span_type:
                span_types.add(str(span_type))
            if status_str:
                span_statuses.add(status_str)
            if name:
                span_names.add(str(name))

            # Skip individual span item if we don't have a span_id
            if not span_id:
                continue

            # Build individual span item
            span_item: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_TRACE_PREFIX}{trace_id}{SK_SPAN_PREFIX}{span_id}",
                "name": str(name),
                "type": str(span_type),
                "status": status_str,
                "start_time_ns": int(start_ns),
                "end_time_ns": int(end_ns),
                "duration_ms": (int(end_ns) - int(start_ns)) // 1_000_000,
            }
            model_name = attrs.get("mlflow.llm.model")
            if model_name:
                try:
                    model_name = _json.loads(model_name)
                except (TypeError, _json.JSONDecodeError):
                    pass
                if model_name:
                    span_item["model_name"] = model_name

            model_provider = attrs.get("mlflow.llm.provider")
            if model_provider:
                try:
                    model_provider = _json.loads(model_provider)
                except (TypeError, _json.JSONDecodeError):
                    pass
                if model_provider:
                    span_item["model_provider"] = model_provider

            if ttl is not None:
                span_item["ttl"] = ttl
            extra_items.append(span_item)

            # --- Token usage (aggregate to trace level) ---
            token_usage_raw = attrs.get("mlflow.chat.tokenUsage")
            if token_usage_raw:
                try:
                    token_usage = (
                        _json.loads(token_usage_raw)
                        if isinstance(token_usage_raw, str)
                        else token_usage_raw
                    )
                    total_input_tokens += int(token_usage.get("input_tokens", 0))
                    total_output_tokens += int(token_usage.get("output_tokens", 0))
                    total_total_tokens += int(token_usage.get("total_tokens", 0))
                    has_token_usage = True
                except (TypeError, _json.JSONDecodeError, ValueError):
                    pass

            # --- Per-span cost metrics ---
            cost_raw = attrs.get("mlflow.llm.cost")
            if cost_raw:
                try:
                    from decimal import Decimal as _Decimal

                    cost = _json.loads(cost_raw) if isinstance(cost_raw, str) else cost_raw
                    for cost_key in ("input_cost", "output_cost", "total_cost"):
                        if cost_key in cost and cost[cost_key] is not None:
                            cost_item: dict[str, Any] = {
                                "PK": pk,
                                "SK": (
                                    f"{SK_TRACE_PREFIX}{trace_id}"
                                    f"{SK_SPAN_METRIC_PREFIX}{span_id}#{cost_key}"
                                ),
                                "value": _Decimal(str(cost[cost_key])),
                                "key": cost_key,
                                "span_id": span_id,
                            }
                            if ttl is not None:
                                cost_item["ttl"] = ttl
                            extra_items.append(cost_item)
                    # Accumulate trace-level cost
                    total_input_cost += float(cost.get("input_cost", 0) or 0)
                    total_output_cost += float(cost.get("output_cost", 0) or 0)
                    total_total_cost += float(cost.get("total_cost", 0) or 0)
                    has_cost = True
                except (TypeError, _json.JSONDecodeError, ValueError):
                    pass

        # --- Write trace-level metric items ---
        if has_token_usage:
            from decimal import Decimal as _Decimal

            for metric_key, metric_val in [
                ("input_tokens", total_input_tokens),
                ("output_tokens", total_output_tokens),
                ("total_tokens", total_total_tokens),
            ]:
                tmetric_item: dict[str, Any] = {
                    "PK": pk,
                    "SK": (f"{SK_TRACE_PREFIX}{trace_id}{SK_TRACE_METRIC_PREFIX}{metric_key}"),
                    "value": _Decimal(str(metric_val)),
                    "key": metric_key,
                }
                if ttl is not None:
                    tmetric_item["ttl"] = ttl
                extra_items.append(tmetric_item)

        # --- Write trace-level token_usage and cost as RMETA items ---
        # Accumulate with existing values (log_spans may be called multiple times)
        if has_token_usage:
            existing_token = self._table.get_item(
                pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}#RMETA#mlflow.trace.tokenUsage"
            )
            if existing_token:
                prev = _json.loads(existing_token["value"])
                total_input_tokens += int(prev.get("input_tokens", 0))
                total_output_tokens += int(prev.get("output_tokens", 0))
                total_total_tokens += int(prev.get("total_tokens", 0))
            token_data = _json.dumps(
                {
                    "input_tokens": total_input_tokens,
                    "output_tokens": total_output_tokens,
                    "total_tokens": total_total_tokens,
                }
            )
            rmeta_token: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_TRACE_PREFIX}{trace_id}#RMETA#mlflow.trace.tokenUsage",
                "key": "mlflow.trace.tokenUsage",
                "value": token_data,
            }
            if ttl is not None:
                rmeta_token["ttl"] = ttl
            extra_items.append(rmeta_token)

        if has_cost:
            existing_cost = self._table.get_item(
                pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}#RMETA#mlflow.trace.cost"
            )
            if existing_cost:
                prev = _json.loads(existing_cost["value"])
                total_input_cost += float(prev.get("input_cost", 0))
                total_output_cost += float(prev.get("output_cost", 0))
                total_total_cost += float(prev.get("total_cost", 0))
            cost_data = _json.dumps(
                {
                    "input_cost": total_input_cost,
                    "output_cost": total_output_cost,
                    "total_cost": total_total_cost,
                }
            )
            rmeta_cost: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_TRACE_PREFIX}{trace_id}#RMETA#mlflow.trace.cost",
                "key": "mlflow.trace.cost",
                "value": cost_data,
            }
            if ttl is not None:
                rmeta_cost["ttl"] = ttl
            extra_items.append(rmeta_cost)

        # Write all extra items in batch
        if extra_items:
            self._table.batch_write(extra_items)

        # --- Denormalize span attributes on META ---
        updates: dict[str, Any] = {}
        if span_types:
            updates["span_types"] = span_types
        if span_statuses:
            updates["span_statuses"] = span_statuses
        if span_names:
            updates["span_names"] = span_names

        # Update trace state from root span if status changed
        root_status = self._get_trace_status_from_spans(trace_spans)
        current_state = meta.get("state", "")
        if root_status and root_status != current_state:
            from mlflow.entities import TraceState

            finalized = {TraceState.OK.value, TraceState.ERROR.value}
            if current_state not in finalized:
                updates["state"] = root_status
                request_time = int(meta.get("request_time", 0))
                updates[LSI3_SK] = f"{root_status}#{request_time:020d}"

        if updates:
            self._table.update_item(pk=pk, sk=sk, updates=updates)

        # Write FTS items for span names
        if span_names:
            span_names_text = " ".join(sorted(span_names))
            fts_items = fts_items_for_text(
                pk=pk,
                entity_type="T",
                entity_id=trace_id,
                field="spans",
                text=span_names_text,
            )
            if ttl is not None:
                for item in fts_items:
                    item["ttl"] = ttl
            self._table.batch_write(fts_items)

    return spans
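
Taken together, one trace's data fans out across several item types under a single partition key. A schematic sketch of the sort-key layout written by log_spans (prefix constants are shown by name; their literal values are defined in the package):

PK = <PK_EXPERIMENT_PREFIX><experiment_id>
  SK = T#<trace_id>                          trace META item
  SK = T#<trace_id>#SPANS                    cached span JSON blob
  SK = T#<trace_id>#SPAN#<span_id>           individual span item
  SK = T#<trace_id>#TMETRIC#<key>            trace-level metric
  SK = T#<trace_id>#SMETRIC#<span_id>#<key>  per-span cost metric
  SK = T#<trace_id>#RMETA#<key>              request metadata (token usage, cost)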

log_spans_async async

log_spans_async(location, spans)

Async version of log_spans — delegates to synchronous implementation.

Source code in src/mlflow_dynamodbstore/tracking_store.py
async def log_spans_async(self, location: str, spans: list[Any]) -> list[Any]:
    """Async version of log_spans — delegates to synchronous implementation."""
    return self.log_spans(location, spans)

search_traces

search_traces(experiment_ids=None, filter_string=None, max_results=100, order_by=None, page_token=None, model_id=None, locations=None)

Search traces across experiments using parse -> plan -> execute pipeline.

For span-level filters (field_type == "span"), uses a hybrid approach:

1. Cached traces (those with denormalized span_types/span_names/span_statuses on META) are filtered via DynamoDB.
2. Uncached traces are found via X-Ray filter expressions.
3. Results are unioned and deduplicated.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def search_traces(
    self,
    experiment_ids: list[str] | None = None,
    filter_string: str | None = None,
    max_results: int = 100,
    order_by: list[str] | None = None,
    page_token: str | None = None,
    model_id: str | None = None,
    locations: list[str] | None = None,
) -> tuple[list[TraceInfo], str | None]:
    """Search traces across experiments using parse -> plan -> execute pipeline.

    For span-level filters (field_type == "span"), uses a hybrid approach:
    1. Cached traces (those with denormalized span_types/span_names/span_statuses
       on META) are filtered via DynamoDB.
    2. Uncached traces are found via X-Ray filter expressions.
    3. Results are unioned and deduplicated.
    """
    from mlflow_dynamodbstore.dynamodb.pagination import (
        decode_page_token,
        encode_page_token,
    )
    from mlflow_dynamodbstore.dynamodb.search import (
        execute_trace_query,
        parse_trace_filter,
        plan_trace_query,
    )

    if not experiment_ids:
        experiment_ids = locations or []

    # 1. Parse filter and split into span vs non-span predicates
    predicates = parse_trace_filter(filter_string)
    span_predicates = [p for p in predicates if p.field_type == "span"]
    non_span_predicates = [p for p in predicates if p.field_type != "span"]

    # 2. Plan query using only non-span predicates
    plan = plan_trace_query(non_span_predicates, order_by)

    # 3. For each experiment: execute query
    token_data = decode_page_token(page_token)
    exp_idx = token_data.get("exp_idx", 0) if token_data else 0
    inner_token = token_data.get("inner_token") if token_data else None

    traces: list[TraceInfo] = []
    remaining = max_results
    next_page_token: str | None = None
    seen_trace_ids: set[str] = set()

    for i, exp_id in enumerate(experiment_ids[exp_idx:], start=exp_idx):
        pk = f"{PK_EXPERIMENT_PREFIX}{exp_id}"
        current_token = inner_token if i == exp_idx else None

        # --- Phase 1: DynamoDB query (handles non-span + cached span filters) ---
        inner_next: str | None = None
        while remaining > 0:
            items, inner_next = execute_trace_query(
                table=self._table,
                plan=plan,
                pk=pk,
                max_results=remaining if not span_predicates else remaining * 3,
                page_token=current_token,
                predicates=non_span_predicates,
            )

            for item in items:
                trace_id = item["trace_id"]
                if trace_id in seen_trace_ids:
                    continue

                # Apply span predicates on cached (denormalized) data
                if span_predicates and not self._match_span_predicates_cached(
                    item, span_predicates
                ):
                    continue

                seen_trace_ids.add(trace_id)
                self._cache.put("trace_exp", trace_id, exp_id)
                trace_info = self._build_trace_info(exp_id, trace_id, item)
                traces.append(trace_info)
                remaining -= 1

                if remaining <= 0:
                    break

            if not inner_next:
                break
            current_token = inner_next

        # --- Phase 2: X-Ray fallback for uncached traces with span filters ---
        if span_predicates and remaining > 0:
            xray_trace_ids = self._search_xray_for_span_filters(exp_id, span_predicates)
            for xray_tid in xray_trace_ids:
                if remaining <= 0:
                    break
                if xray_tid in seen_trace_ids:
                    continue
                # Verify this trace exists in DynamoDB and belongs to this experiment
                meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{xray_tid}")
                if meta is None:
                    continue
                # Apply non-span post-filters on this item too
                from mlflow_dynamodbstore.dynamodb.search import (
                    _apply_trace_post_filter,
                )

                if not all(
                    _apply_trace_post_filter(self._table, pk, xray_tid, meta, p)
                    for p in non_span_predicates
                ):
                    continue

                seen_trace_ids.add(xray_tid)
                self._cache.put("trace_exp", xray_tid, exp_id)
                trace_info = self._build_trace_info(exp_id, xray_tid, meta)
                traces.append(trace_info)
                remaining -= 1

        if remaining <= 0:
            if inner_next or i < len(experiment_ids) - 1:
                next_page_token = encode_page_token(
                    {
                        "exp_idx": i if inner_next else i + 1,
                        "inner_token": inner_next,
                    }
                )
            break

        if inner_next:
            next_page_token = encode_page_token(
                {
                    "exp_idx": i,
                    "inner_token": inner_next,
                }
            )
            break

    return traces, next_page_token
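
A usage sketch, assuming store is initialized; the experiment ID is a placeholder and the filter string is illustrative (the supported syntax is whatever parse_trace_filter accepts):

page_token = None
while True:
    infos, page_token = store.search_traces(
        experiment_ids=["<experiment_ulid>"],
        filter_string="status = 'OK'",
        max_results=50,
        page_token=page_token,
    )
    for info in infos:
        print(info.trace_id)
    if page_token is None:
        break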

calculate_trace_filter_correlation

calculate_trace_filter_correlation(experiment_ids, filter_string1, filter_string2, base_filter=None)

Calculate NPMI correlation between two trace filters.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def calculate_trace_filter_correlation(
    self,
    experiment_ids: list[str],
    filter_string1: str,
    filter_string2: str,
    base_filter: str | None = None,
) -> TraceFilterCorrelationResult:
    """Calculate NPMI correlation between two trace filters."""

    from mlflow_dynamodbstore.dynamodb.search import (
        _apply_trace_post_filter,
        execute_trace_query,
        parse_trace_filter,
        plan_trace_query,
    )

    preds1 = parse_trace_filter(filter_string1)
    preds2 = parse_trace_filter(filter_string2)
    base_preds = parse_trace_filter(base_filter)

    # Plan query using base_filter predicates (for efficient index usage)
    plan = plan_trace_query(base_preds, None)

    total_count = 0
    filter1_count = 0
    filter2_count = 0
    joint_count = 0

    for exp_id in experiment_ids:
        pk = f"{PK_EXPERIMENT_PREFIX}{exp_id}"
        page_token: str | None = None

        while True:
            items, page_token = execute_trace_query(
                table=self._table,
                plan=plan,
                pk=pk,
                max_results=1000,
                page_token=page_token,
                predicates=base_preds,
            )

            for item in items:
                trace_id = item["trace_id"]
                total_count += 1

                match1 = all(
                    _apply_trace_post_filter(self._table, pk, trace_id, item, p) for p in preds1
                )
                match2 = all(
                    _apply_trace_post_filter(self._table, pk, trace_id, item, p) for p in preds2
                )

                if match1:
                    filter1_count += 1
                if match2:
                    filter2_count += 1
                if match1 and match2:
                    joint_count += 1

            if not page_token:
                break

    # Compute NPMI using MLflow's standard implementation
    from mlflow.store.analytics.trace_correlation import calculate_npmi_from_counts

    npmi_result = calculate_npmi_from_counts(
        joint_count=joint_count,
        filter1_count=filter1_count,
        filter2_count=filter2_count,
        total_count=total_count,
    )

    return TraceFilterCorrelationResult(
        npmi=npmi_result.npmi,
        npmi_smoothed=npmi_result.npmi_smoothed,
        filter1_count=filter1_count,
        filter2_count=filter2_count,
        joint_count=joint_count,
        total_count=total_count,
    )

set_trace_tag

set_trace_tag(trace_id, key, value)

Set a tag on a trace.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def set_trace_tag(self, trace_id: str, key: str, value: str) -> None:
    """Set a tag on a trace."""
    from mlflow.utils.validation import _validate_length_limit, _validate_tag_name

    _validate_tag_name(key)
    _validate_length_limit("key", 250, key)
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    # Read TTL from the trace META item
    meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}")
    if meta is None:
        raise MlflowException(
            f"Trace with ID {trace_id} is not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    ttl = int(meta["ttl"]) if "ttl" in meta else self._get_trace_ttl()
    self._write_trace_tag(experiment_id, trace_id, key, value, ttl)
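
A hypothetical call, assuming store is initialized (delete_trace_tag, documented next, removes the tag again):

store.set_trace_tag("tr-0123456789abcdef", "env", "staging")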

delete_trace_tag

delete_trace_tag(trace_id, key)

Delete a tag from a trace.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_trace_tag(self, trace_id: str, key: str) -> None:
    """Delete a tag from a trace."""
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}#TAG#{key}"
    self._table.delete_item(pk=pk, sk=sk)
    if self._config.should_denormalize(experiment_id, key):
        self._remove_denormalized_tag(pk, f"{SK_TRACE_PREFIX}{trace_id}", key)
    # Remove FTS items for tag value if FTS was configured
    if self._config.should_trigram("trace_tag_value"):
        self._delete_fts_for_entity_field(pk=pk, entity_type="T", entity_id=trace_id, field=key)

link_traces_to_run

link_traces_to_run(trace_ids, run_id)

Link traces to a run by writing mlflow.sourceRun request metadata.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def link_traces_to_run(self, trace_ids: list[str], run_id: str) -> None:
    """Link traces to a run by writing mlflow.sourceRun request metadata."""
    for trace_id in trace_ids:
        experiment_id = self._resolve_trace_experiment(trace_id)
        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
        # Read TTL from the trace META item
        meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}")
        ttl = int(meta["ttl"]) if meta and "ttl" in meta else self._get_trace_ttl()
        rmeta_item: dict[str, Any] = {
            "PK": pk,
            "SK": f"{SK_TRACE_PREFIX}{trace_id}#RMETA#{TraceMetadataKey.SOURCE_RUN}",
            "key": TraceMetadataKey.SOURCE_RUN,
            "value": run_id,
        }
        if ttl is not None:
            rmeta_item["ttl"] = ttl
        self._table.put_item(rmeta_item)

unlink_traces_from_run

unlink_traces_from_run(trace_ids, run_id)

Unlink traces from a run by deleting mlflow.sourceRun RMETA items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def unlink_traces_from_run(self, trace_ids: list[str], run_id: str) -> None:
    """Unlink traces from a run by deleting mlflow.sourceRun RMETA items."""
    for trace_id in trace_ids:
        try:
            experiment_id = self._resolve_trace_experiment(trace_id)
        except MlflowException:
            continue
        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
        sk = f"{SK_TRACE_PREFIX}{trace_id}#RMETA#{TraceMetadataKey.SOURCE_RUN}"
        self._table.delete_item(pk=pk, sk=sk)
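
A hypothetical link/unlink round trip, with placeholder IDs:

store.link_traces_to_run(["tr-abc", "tr-def"], run_id="run-123")
store.unlink_traces_from_run(["tr-abc", "tr-def"], run_id="run-123")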

link_prompts_to_trace

link_prompts_to_trace(trace_id, prompt_versions)

Link prompt versions to a trace by writing mlflow.promptVersions tag.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def link_prompts_to_trace(self, trace_id: str, prompt_versions: list[PromptVersion]) -> None:
    """Link prompt versions to a trace by writing mlflow.promptVersions tag."""
    import json as _json

    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}")
    if meta is None:
        raise MlflowException(
            f"Trace with ID {trace_id} is not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    ttl = int(meta["ttl"]) if "ttl" in meta else self._get_trace_ttl()

    versions_json = _json.dumps(
        [{"name": pv.name, "version": pv.version} for pv in prompt_versions]
    )
    self._write_trace_tag(experiment_id, trace_id, "mlflow.promptVersions", versions_json, ttl)

batch_get_trace_infos

batch_get_trace_infos(trace_ids, location=None)

Get trace metadata for given trace IDs without loading spans.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def batch_get_trace_infos(
    self, trace_ids: list[str], location: str | None = None
) -> list[TraceInfo]:
    """Get trace metadata for given trace IDs without loading spans."""
    if not trace_ids:
        return []

    seen: set[str] = set()
    results: list[TraceInfo] = []

    for trace_id in trace_ids:
        if trace_id in seen:
            continue
        seen.add(trace_id)

        try:
            if location:
                experiment_id = location
            else:
                experiment_id = self._resolve_trace_experiment(trace_id)
        except MlflowException:
            continue  # Skip non-existent traces

        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
        meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}")
        if meta is None:
            continue

        trace_info = self._build_trace_info(experiment_id, trace_id, meta)
        results.append(trace_info)

    return results

batch_get_traces

batch_get_traces(trace_ids, location=None)

Get complete traces with spans for given trace IDs.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def batch_get_traces(self, trace_ids: list[str], location: str | None = None) -> list[Trace]:
    """Get complete traces with spans for given trace IDs."""
    import json as _json

    from mlflow.entities.trace import Trace
    from mlflow.entities.trace_data import TraceData

    from mlflow_dynamodbstore.xray.span_converter import span_dicts_to_mlflow_spans

    if not trace_ids:
        return []

    seen: set[str] = set()
    results: list[Trace] = []

    for trace_id in trace_ids:
        if trace_id in seen:
            continue
        seen.add(trace_id)

        try:
            if location:
                experiment_id = location
            else:
                experiment_id = self._resolve_trace_experiment(trace_id)
        except MlflowException:
            continue

        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
        meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}")
        if meta is None:
            continue

        trace_info = self._build_trace_info(experiment_id, trace_id, meta)

        # Read cached spans
        spans_sk = f"{SK_TRACE_PREFIX}{trace_id}#SPANS"
        cached = self._table.get_item(pk=pk, sk=spans_sk)
        if cached is not None:
            span_dicts = _json.loads(cached["data"])
            # Try V3 format (Span.to_dict) first, fall back to X-Ray format
            if span_dicts and "start_time_unix_nano" in span_dicts[0]:
                try:
                    from mlflow.entities.span import Span as SpanEntity

                    spans = [SpanEntity.from_dict(sd) for sd in span_dicts]
                except Exception:
                    spans = span_dicts_to_mlflow_spans(span_dicts, trace_id)
            else:
                # X-Ray converter format
                spans = span_dicts_to_mlflow_spans(span_dicts, trace_id)
        else:
            spans = []

        results.append(Trace(info=trace_info, data=TraceData(spans=spans)))

    return results
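
A hypothetical batch fetch with placeholder IDs:

infos = store.batch_get_trace_infos(["tr-abc", "tr-def", "tr-abc"])  # duplicate IDs are skipped
traces = store.batch_get_traces(["tr-abc", "tr-def"])  # spans are read from cached SPANS items only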

find_completed_sessions

find_completed_sessions(experiment_id, min_last_trace_timestamp_ms, max_last_trace_timestamp_ms, max_results=None, filter_string=None)

Find completed sessions by last trace timestamp range via GSI2.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def find_completed_sessions(
    self,
    experiment_id: str,
    min_last_trace_timestamp_ms: int,
    max_last_trace_timestamp_ms: int,
    max_results: int | None = None,
    filter_string: str | None = None,
) -> list[Any]:
    """Find completed sessions by last trace timestamp range via GSI2."""
    from mlflow.genai.scorers.online.entities import CompletedSession

    gsi2pk = f"{GSI2_SESSIONS_PREFIX}{self._workspace}#{experiment_id}"

    items = self._table.query(
        pk=gsi2pk,
        sk_gte=f"{min_last_trace_timestamp_ms:020d}",
        sk_lte=f"{max_last_trace_timestamp_ms:020d}",
        index_name="gsi2",
        scan_forward=True,
    )

    # Optional: post-filter sessions by trace attributes
    if filter_string:
        from mlflow_dynamodbstore.dynamodb.search import (
            _apply_trace_post_filter,
            parse_trace_filter,
        )

        preds = parse_trace_filter(filter_string)
        filtered_items = []
        for item in items:
            session_id = item["session_id"]
            exp_pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
            trace_items = self._table.query(
                pk=exp_pk,
                sk_prefix=SK_TRACE_PREFIX,
            )
            session_qualifies = False
            for t_item in trace_items:
                if "trace_id" not in t_item:
                    continue
                tid = t_item["trace_id"]
                rmeta_sk = f"{SK_TRACE_PREFIX}{tid}#RMETA#mlflow.traceSession"
                rmeta = self._table.get_item(pk=exp_pk, sk=rmeta_sk)
                if rmeta and rmeta.get("value") == session_id:
                    if all(
                        _apply_trace_post_filter(self._table, exp_pk, tid, t_item, p)
                        for p in preds
                    ):
                        session_qualifies = True
                        break
            if session_qualifies:
                filtered_items.append(item)
        items = filtered_items

    results: list[CompletedSession] = []
    for item in items:
        session = CompletedSession(
            session_id=item["session_id"],
            first_trace_timestamp_ms=int(item["first_trace_timestamp_ms"]),
            last_trace_timestamp_ms=int(item["last_trace_timestamp_ms"]),
        )
        results.append(session)

    if max_results is not None:
        results = results[:max_results]

    return results
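
The zero-padded sort keys are what make a numeric timestamp range work as a lexicographic BETWEEN on the sort key; a quick illustration:

ts_low, ts_high = 999, 1000
assert str(ts_low) > str(ts_high)            # plain strings sort wrongly: "999" > "1000"
assert f"{ts_low:020d}" < f"{ts_high:020d}"  # 20-digit padding restores numeric order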

create_assessment

create_assessment(assessment)

Create a new assessment for a trace.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def create_assessment(self, assessment: Assessment) -> Assessment:
    """Create a new assessment for a trace."""
    trace_id = assessment.trace_id
    if trace_id is None:
        raise MlflowException(
            "Assessment must have a trace_id.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

    # Read TTL from the trace META item
    meta = self._table.get_item(pk=pk, sk=f"{SK_TRACE_PREFIX}{trace_id}")
    if meta is None:
        raise MlflowException(
            f"Trace with ID {trace_id} is not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    ttl = int(meta["ttl"]) if "ttl" in meta else self._get_trace_ttl()

    # Generate assessment ID
    assessment_id = generate_ulid()
    now_ms = int(time.time() * 1000)

    # Build the assessment item storing the full serialized assessment dict
    assess_dict = assessment.to_dictionary()
    assess_dict["assessment_id"] = assessment_id
    assess_dict["create_time"] = assess_dict.get(
        "create_time"
    ) or milliseconds_to_proto_timestamp(now_ms)
    assess_dict["last_update_time"] = assess_dict.get(
        "last_update_time"
    ) or milliseconds_to_proto_timestamp(now_ms)

    sk = f"{SK_TRACE_PREFIX}{trace_id}#ASSESS#{assessment_id}"
    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "data": assess_dict,
    }
    if ttl is not None:
        item["ttl"] = ttl
    self._denormalize_assessment_item(item, assess_dict)
    self._table.put_item(item)

    # Write FTS items for the assessment value text
    fts_text = self._assessment_fts_text(assessment)
    if fts_text:
        self._write_assessment_fts(pk, trace_id, assessment_id, fts_text, ttl)

    # Return the assessment with the generated ID
    result = Assessment.from_dictionary(assess_dict)
    return result

get_assessment

get_assessment(trace_id, assessment_id)

Fetch an assessment by trace ID and assessment ID.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_assessment(self, trace_id: str, assessment_id: str) -> Assessment:
    """Fetch an assessment by trace ID and assessment ID."""
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}#ASSESS#{assessment_id}"

    item = self._table.get_item(pk=pk, sk=sk)
    if item is None:
        raise MlflowException(
            f"Assessment '{assessment_id}' for trace '{trace_id}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    return Assessment.from_dictionary(item["data"])

update_assessment

update_assessment(trace_id, assessment_id, name=None, expectation=None, feedback=None, rationale=None, metadata=None)

Update mutable fields of an assessment.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def update_assessment(
    self,
    trace_id: str,
    assessment_id: str,
    name: str | None = None,
    expectation: str | None = None,
    feedback: str | None = None,
    rationale: str | None = None,
    metadata: dict[str, str] | None = None,
) -> Assessment:
    """Update mutable fields of an assessment."""
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}#ASSESS#{assessment_id}"

    item = self._table.get_item(pk=pk, sk=sk)
    if item is None:
        raise MlflowException(
            f"Assessment '{assessment_id}' for trace '{trace_id}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    ttl = int(item["ttl"]) if "ttl" in item else self._get_trace_ttl()

    assess_dict = item["data"]

    # Capture old text for FTS diff
    old_assessment = Assessment.from_dictionary(assess_dict)
    old_fts_text = self._assessment_fts_text(old_assessment)

    # Apply updates
    now_ms = int(time.time() * 1000)
    assess_dict["last_update_time"] = milliseconds_to_proto_timestamp(now_ms)

    if name is not None:
        assess_dict["assessment_name"] = name
    if feedback is not None:
        assess_dict["feedback"] = {"value": feedback}
        assess_dict.pop("expectation", None)
    if expectation is not None:
        assess_dict["expectation"] = {"value": expectation}
        assess_dict.pop("feedback", None)
    if rationale is not None:
        assess_dict["rationale"] = rationale
    if metadata is not None:
        assess_dict["metadata"] = metadata

    # Write updated item
    item["data"] = assess_dict
    self._denormalize_assessment_item(item, assess_dict)
    self._table.put_item(item)

    # FTS diff
    updated_assessment = Assessment.from_dictionary(assess_dict)
    new_fts_text = self._assessment_fts_text(updated_assessment)
    field = f"assess_{assessment_id}"

    if old_fts_text or new_fts_text:
        levels = ("W", "3", "2")
        tokens_to_add, tokens_to_remove = fts_diff(old_fts_text, new_fts_text or "", levels)
        field_suffix = f"#{field}"
        entity_prefix = f"T#{trace_id}{field_suffix}"

        # Delete removed FTS items
        if tokens_to_remove:
            rev_prefix = f"{SK_FTS_REV_PREFIX}{entity_prefix}#"
            rev_items = self._table.query(pk=pk, sk_prefix=rev_prefix)
            for rev_item in rev_items:
                rev_sk = rev_item["SK"]
                suffix = rev_sk[len(SK_FTS_REV_PREFIX) + len(entity_prefix) + 1 :]
                parts = suffix.split("#", 1)
                if len(parts) == 2:
                    lvl, tok = parts[0], parts[1]
                    if (lvl, tok) in tokens_to_remove:
                        forward_sk = f"{SK_FTS_PREFIX}{lvl}#T#{tok}#{trace_id}{field_suffix}"
                        self._table.delete_item(pk=pk, sk=forward_sk)
                        self._table.delete_item(pk=pk, sk=rev_sk)

        # Write new FTS items
        if tokens_to_add:
            new_fts_items: list[dict[str, Any]] = []
            for lvl, tok in tokens_to_add:
                forward_sk = f"{SK_FTS_PREFIX}{lvl}#T#{tok}#{trace_id}{field_suffix}"
                reverse_sk = f"{SK_FTS_REV_PREFIX}{entity_prefix}#{lvl}#{tok}"
                new_fwd: dict[str, Any] = {"PK": pk, "SK": forward_sk}
                new_rev: dict[str, Any] = {"PK": pk, "SK": reverse_sk}
                if ttl is not None:
                    new_fwd["ttl"] = ttl
                    new_rev["ttl"] = ttl
                new_fts_items.append(new_fwd)
                new_fts_items.append(new_rev)
            self._table.batch_write(new_fts_items)

    return updated_assessment
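
Schematically, each surviving token is stored as a forward/reverse item pair (prefix constants shown by name):

forward SK: <SK_FTS_PREFIX><level>#T#<token>#<trace_id>#assess_<assessment_id>
reverse SK: <SK_FTS_REV_PREFIX>T#<trace_id>#assess_<assessment_id>#<level>#<token>

The reverse items let deletion and diffing enumerate a trace's tokens directly instead of scanning the forward index.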

delete_assessment

delete_assessment(trace_id, assessment_id)

Delete an assessment and clean up its FTS items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_assessment(self, trace_id: str, assessment_id: str) -> None:
    """Delete an assessment and clean up its FTS items."""
    experiment_id = self._resolve_trace_experiment(trace_id)
    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    sk = f"{SK_TRACE_PREFIX}{trace_id}#ASSESS#{assessment_id}"

    item = self._table.get_item(pk=pk, sk=sk)
    if item is None:
        raise MlflowException(
            f"Assessment '{assessment_id}' for trace '{trace_id}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Delete assessment item
    self._table.delete_item(pk=pk, sk=sk)

    # Clean up FTS items via reverse index
    field = f"assess_{assessment_id}"
    self._delete_fts_for_entity_field(pk=pk, entity_type="T", entity_id=trace_id, field=field)

delete_traces

delete_traces(experiment_id, max_timestamp_millis=None, max_traces=None, trace_ids=None)

Delete traces and all their sub-items (tags, metadata, assessments, FTS).

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_traces(
    self,
    experiment_id: str,
    max_timestamp_millis: int | None = None,
    max_traces: int | None = None,
    trace_ids: list[str] | None = None,
) -> int:
    """Delete traces and all their sub-items (tags, metadata, assessments, FTS)."""
    if not trace_ids:
        return 0

    pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"
    deleted = 0

    for trace_id in trace_ids:
        # 1. Query all trace sub-items: SK begins_with T#<trace_id>
        trace_prefix = f"{SK_TRACE_PREFIX}{trace_id}"
        trace_items = self._table.query(pk=pk, sk_prefix=trace_prefix)

        # 2. Query FTS_REV items for this trace to find forward FTS items
        fts_rev_prefix = f"{SK_FTS_REV_PREFIX}T#{trace_id}"
        fts_rev_items = self._table.query(pk=pk, sk_prefix=fts_rev_prefix)

        # 3. Derive forward FTS SKs from reverse items and delete them
        for rev_item in fts_rev_items:
            rev_sk = rev_item["SK"]
            # Everything after "FTS_REV#"
            after_rev_prefix = rev_sk[len(SK_FTS_REV_PREFIX) :]
            # after_rev_prefix: T#<trace_id>[#<field>]#<level>#<token>
            # We need to split into entity_prefix and level#token.
            # entity_prefix always starts with T#<trace_id>.
            # After T#<trace_id>, there may be #<field> or directly #<level>#<token>.
            # Level is always a single character (W, 3, or 2).
            # So we look for the pattern where after the entity part,
            # we have #<single_char>#<rest> where single_char is the level.
            base = f"T#{trace_id}"
            rest = after_rev_prefix[len(base) :]
            # rest starts with '#'
            # Split: could be #<field>#<level>#<token> or #<level>#<token>
            # Level chars: W, 3, 2
            parts = rest.lstrip("#").split("#")
            # Try to find the level marker
            # If parts[0] is a known level (W, 3, or 2), then no field
            # Otherwise parts[0] is field, parts[1] is level
            if parts[0] in ("W", "3", "2"):
                field_part = ""
                lvl = parts[0]
                tok = "#".join(parts[1:])
            else:
                field_part = f"#{parts[0]}"
                lvl = parts[1]
                tok = "#".join(parts[2:])

            forward_sk = f"{SK_FTS_PREFIX}{lvl}#T#{tok}#{trace_id}{field_part}"
            self._table.delete_item(pk=pk, sk=forward_sk)
            self._table.delete_item(pk=pk, sk=rev_sk)

        # 4. Delete all trace sub-items (META, tags, RMETA, assessments, CLIENTPTR)
        for item in trace_items:
            self._table.delete_item(pk=pk, sk=item["SK"])

        deleted += 1

    return deleted
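
A standalone sketch of the reverse-to-forward SK derivation described in the comments above (the prefix values passed in are illustrative; the real ones come from the SK_FTS_* constants):

def forward_sk_from_reverse(rev_sk: str, trace_id: str, rev_prefix: str, fwd_prefix: str) -> str:
    # rev_sk shape: <rev_prefix>T#<trace_id>[#<field>]#<level>#<token>
    rest = rev_sk[len(rev_prefix) + len(f"T#{trace_id}"):]
    parts = rest.lstrip("#").split("#")
    if parts[0] in ("W", "3", "2"):  # no field segment before the level marker
        field, lvl, tok = "", parts[0], "#".join(parts[1:])
    else:  # first segment is the field name
        field, lvl, tok = f"#{parts[0]}", parts[1], "#".join(parts[2:])
    return f"{fwd_prefix}{lvl}#T#{tok}#{trace_id}{field}"

# forward_sk_from_reverse("FTS_REV#T#tr1#spans#W#hello", "tr1", "FTS_REV#", "FTS#")
# returns "FTS#W#T#hello#tr1#spans"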

create_dataset

create_dataset(name, tags=None, experiment_ids=None)

Create a new evaluation dataset and return it.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def create_dataset(
    self,
    name: str,
    tags: dict[str, str] | None = None,
    experiment_ids: list[str] | None = None,
) -> EvaluationDataset:
    """Create a new evaluation dataset and return it."""
    # Check name uniqueness via GSI3
    existing = self._table.query(
        pk=f"{GSI3_DS_NAME_PREFIX}{self._workspace}#{name.lower()}",
        index_name="gsi3",
        limit=1,
    )
    if existing:
        raise MlflowException(
            f"Dataset with name '{name}' already exists.",
            error_code=RESOURCE_ALREADY_EXISTS,
        )

    now_ms = int(time.time() * 1000)
    dataset_id = f"d-{generate_ulid()}"
    digest = self._compute_dataset_digest(name, now_ms)
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"

    # Write META item
    meta_item: dict[str, Any] = {
        "PK": pk,
        "SK": SK_DATASET_META,
        "name": name,
        "digest": digest,
        "created_time": now_ms,
        "last_update_time": now_ms,
        "workspace": self._workspace,
        "tags": tags or {},
        # LSI projections
        LSI1_SK: f"{now_ms:020d}",
        LSI2_SK: now_ms,
        LSI3_SK: name.lower(),
        # GSI2: list all datasets in workspace
        GSI2_PK: f"{GSI2_DS_LIST_PREFIX}{self._workspace}",
        GSI2_SK: dataset_id,
        # GSI3: name uniqueness
        GSI3_PK: f"{GSI3_DS_NAME_PREFIX}{self._workspace}#{name.lower()}",
        GSI3_SK: dataset_id,
    }
    self._table.put_item(meta_item)

    # Write tag items
    if tags:
        for key, value in tags.items():
            tag_item: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_DATASET_TAG_PREFIX}{key}",
                "key": key,
                "value": value,
            }
            self._table.put_item(tag_item)

    # Write experiment link items
    if experiment_ids:
        for exp_id in experiment_ids:
            exp_link_item: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_DATASET_EXP_PREFIX}{exp_id}",
                # GSI1 reverse lookup: from experiment -> datasets
                GSI1_PK: f"{GSI1_DS_EXP_PREFIX}{exp_id}",
                GSI1_SK: dataset_id,
            }
            self._table.put_item(exp_link_item)

    return EvaluationDataset(
        dataset_id=dataset_id,
        name=name,
        digest=digest,
        created_time=now_ms,
        last_update_time=now_ms,
        tags=tags or {},
    )
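
A hypothetical creation call; the name, tag, and experiment ID are placeholders:

ds = store.create_dataset(
    name="qa-regression",
    tags={"team": "ml-platform"},
    experiment_ids=["<experiment_ulid>"],
)
print(ds.dataset_id)  # IDs take the form "d-<ulid>"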

get_dataset

get_dataset(dataset_id)

Fetch an evaluation dataset by ID.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_dataset(self, dataset_id: str) -> EvaluationDataset:
    """Fetch an evaluation dataset by ID."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is None:
        raise MlflowException(
            f"Dataset '{dataset_id}' does not exist.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Load tags from tag items
    tag_items = self._table.query(pk=pk, sk_prefix=SK_DATASET_TAG_PREFIX)
    tags = {item["key"]: item["value"] for item in tag_items}

    # Load experiment IDs
    experiment_ids = self.get_dataset_experiment_ids(dataset_id)

    ds = EvaluationDataset(
        dataset_id=dataset_id,
        name=meta["name"],
        digest=meta["digest"],
        created_time=int(meta["created_time"]),
        last_update_time=int(meta["last_update_time"]),
        tags=tags,
        profile=meta.get("profile"),
    )
    ds.experiment_ids = experiment_ids
    return ds

delete_dataset

delete_dataset(dataset_id)

Delete an evaluation dataset and all its sub-items.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_dataset(self, dataset_id: str) -> None:
    """Delete an evaluation dataset and all its sub-items."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    all_items = self._table.query(pk=pk)
    if not all_items:
        # Idempotent: nothing to delete
        return
    keys = [{"PK": item["PK"], "SK": item["SK"]} for item in all_items]
    self._table.batch_delete(keys)

get_dataset_experiment_ids

get_dataset_experiment_ids(dataset_id)

Return experiment IDs associated with a dataset.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def get_dataset_experiment_ids(self, dataset_id: str) -> list[str]:
    """Return experiment IDs associated with a dataset."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    items = self._table.query(pk=pk, sk_prefix=SK_DATASET_EXP_PREFIX)
    return [item["SK"][len(SK_DATASET_EXP_PREFIX) :] for item in items]

search_datasets

search_datasets(experiment_ids=None, filter_string=None, max_results=1000, order_by=None, page_token=None)

Search evaluation datasets with optional filtering, ordering, and pagination.

Parameters:

  experiment_ids (list[str] | None, default None)
      Limit results to datasets linked to these experiments.
  filter_string (str | None, default None)
      Filter expression; supports name LIKE 'pattern'.
  max_results (int, default 1000)
      Maximum number of results per page.
  order_by (list[str] | None, default None)
      Ordering criteria, e.g. ["name ASC"].
  page_token (str | None, default None)
      Opaque token for fetching the next page.

Returns:

  PagedList[EvaluationDataset]
      A PagedList of EvaluationDataset objects.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def search_datasets(
    self,
    experiment_ids: list[str] | None = None,
    filter_string: str | None = None,
    max_results: int = 1000,
    order_by: list[str] | None = None,
    page_token: str | None = None,
) -> PagedList[EvaluationDataset]:
    """Search evaluation datasets with optional filtering, ordering, and pagination.

    Args:
        experiment_ids: Limit results to datasets linked to these experiments.
        filter_string: Filter expression, supports ``name LIKE 'pattern'``.
        max_results: Maximum number of results per page.
        order_by: Ordering criteria, e.g. ``["name ASC"]``.
        page_token: Opaque token for fetching the next page.

    Returns:
        A PagedList of EvaluationDataset objects.
    """
    import re

    from mlflow_dynamodbstore.dynamodb.pagination import (
        decode_page_token,
        encode_page_token,
    )

    # Decode pagination state
    token_state = decode_page_token(page_token)
    offset = token_state["offset"] if token_state else 0

    # --- Collect all matching dataset META items ---
    if experiment_ids:
        # AP5 + AP13: query GSI1 for each experiment, then get META items
        dataset_ids: list[str] = []
        seen: set[str] = set()
        for exp_id in experiment_ids:
            gsi1_pk = f"{GSI1_DS_EXP_PREFIX}{exp_id}"
            link_items = self._table.query(pk=gsi1_pk, index_name="gsi1")
            for item in link_items:
                did = item[GSI1_SK]
                if did not in seen:
                    seen.add(did)
                    dataset_ids.append(did)

        all_datasets: list[EvaluationDataset] = []
        for did in dataset_ids:
            pk = f"{PK_DATASET_PREFIX}{did}"
            meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
            if meta is None:
                continue
            all_datasets.append(
                EvaluationDataset(
                    dataset_id=did,
                    name=meta["name"],
                    digest=meta["digest"],
                    created_time=int(meta["created_time"]),
                    last_update_time=int(meta["last_update_time"]),
                    tags=meta.get("tags") or {},
                )
            )
    else:
        # AP4: query GSI2 for all datasets in workspace
        gsi2_pk = f"{GSI2_DS_LIST_PREFIX}{self._workspace}"
        meta_items = self._table.query(pk=gsi2_pk, index_name="gsi2")
        all_datasets = []
        for item in meta_items:
            did = item[GSI2_SK]
            pk = f"{PK_DATASET_PREFIX}{did}"
            meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
            if meta is None:
                continue
            all_datasets.append(
                EvaluationDataset(
                    dataset_id=did,
                    name=meta["name"],
                    digest=meta["digest"],
                    created_time=int(meta["created_time"]),
                    last_update_time=int(meta["last_update_time"]),
                    tags=meta.get("tags") or {},
                )
            )

    # --- Apply filter_string in-memory ---
    if filter_string:
        # Support LIKE patterns: exact, prefix 'p%', suffix '%p', substring '%p%'
        like_match = re.match(
            r"""name\s+LIKE\s+['"](.+)['"]\s*$""", filter_string.strip(), re.IGNORECASE
        )
        if like_match:
            pattern = like_match.group(1)
            if pattern.startswith("%") and pattern.endswith("%"):
                substring = pattern.strip("%")
                all_datasets = [d for d in all_datasets if substring in d.name]
            elif pattern.endswith("%"):
                prefix = pattern.rstrip("%")
                all_datasets = [d for d in all_datasets if d.name.startswith(prefix)]
            elif pattern.startswith("%"):
                suffix = pattern.lstrip("%")
                all_datasets = [d for d in all_datasets if d.name.endswith(suffix)]
            else:
                all_datasets = [d for d in all_datasets if d.name == pattern]

    # --- Apply order_by in-memory ---
    if order_by:
        for criterion in reversed(order_by):
            parts = criterion.strip().split()
            field = parts[0].lower()
            direction = parts[1].upper() if len(parts) > 1 else "ASC"
            reverse = direction == "DESC"
            if field == "name":
                all_datasets.sort(key=lambda d: d.name, reverse=reverse)
            elif field == "created_time":
                all_datasets.sort(key=lambda d: d.created_time, reverse=reverse)
            elif field == "last_update_time":
                all_datasets.sort(key=lambda d: d.last_update_time, reverse=reverse)

    # --- Apply pagination ---
    total = len(all_datasets)
    page = all_datasets[offset : offset + max_results]
    next_offset = offset + len(page)
    next_token = encode_page_token({"offset": next_offset}) if next_offset < total else None

    return PagedList(page, next_token)
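
A minimal pagination sketch follows. The store URI shape (dynamodb://<table>?region=...) and the dataset name prefix are illustrative assumptions, not part of this API's contract:

from mlflow_dynamodbstore.tracking_store import DynamoDBTrackingStore

# Hypothetical table and URI; substitute your own store URI.
store = DynamoDBTrackingStore("dynamodb://mlflow-table?region=us-east-1")

page_token = None
while True:
    page = store.search_datasets(
        filter_string="name LIKE 'eval-%'",  # prefix match, applied in-memory
        order_by=["created_time DESC"],
        max_results=100,
        page_token=page_token,
    )
    for ds in page:
        print(ds.dataset_id, ds.name)
    page_token = page.token  # opaque offset-based token
    if page_token is None:
        break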

set_dataset_tags

set_dataset_tags(dataset_id, tags)

Set (upsert) one or more tags on an evaluation dataset.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def set_dataset_tags(self, dataset_id: str, tags: dict[str, str]) -> None:
    """Set (upsert) one or more tags on an evaluation dataset."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    for key, value in tags.items():
        # Write individual tag item (overwrite = upsert)
        self._table.put_item(
            {
                "PK": pk,
                "SK": f"{SK_DATASET_TAG_PREFIX}{key}",
                "key": key,
                "value": value,
            }
        )
        # Update denormalized tags map on META item
        self._denormalize_tag(pk=pk, sk=SK_DATASET_META, tag_key=key, tag_value=value)

    # Update last_update_time and digest
    now_ms = int(time.time() * 1000)
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is not None:
        digest = self._compute_dataset_digest(meta["name"], now_ms)
        self._table.update_item(
            pk=pk,
            sk=SK_DATASET_META,
            updates={"last_update_time": now_ms, "digest": digest},
        )

delete_dataset_tag

delete_dataset_tag(dataset_id, key)

Delete a single tag from an evaluation dataset.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_dataset_tag(self, dataset_id: str, key: str) -> None:
    """Delete a single tag from an evaluation dataset."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    # Delete the individual tag item
    self._table.delete_item(pk=pk, sk=f"{SK_DATASET_TAG_PREFIX}{key}")
    # Remove from denormalized tags map on META
    self._remove_denormalized_tag(pk=pk, sk=SK_DATASET_META, tag_key=key)
    # Update last_update_time and digest
    now_ms = int(time.time() * 1000)
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is not None:
        digest = self._compute_dataset_digest(meta["name"], now_ms)
        self._table.update_item(
            pk=pk,
            sk=SK_DATASET_META,
            updates={"last_update_time": now_ms, "digest": digest},
        )
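
As a usage sketch (the store instance and dataset_id are assumed to exist; the tag keys are hypothetical):

# set_dataset_tags upserts one item per tag and denormalizes onto META;
# delete_dataset_tag removes both copies. Each call refreshes
# last_update_time and recomputes the dataset digest.
store.set_dataset_tags(dataset_id, {"team": "search", "stage": "eval"})
store.delete_dataset_tag(dataset_id, "stage")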

add_dataset_to_experiments

add_dataset_to_experiments(dataset_id, experiment_ids)

Associate an evaluation dataset with one or more experiments.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def add_dataset_to_experiments(
    self, dataset_id: str, experiment_ids: list[str]
) -> EvaluationDataset:
    """Associate an evaluation dataset with one or more experiments."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    for exp_id in experiment_ids:
        # put_item overwrites = idempotent
        self._table.put_item(
            {
                "PK": pk,
                "SK": f"{SK_DATASET_EXP_PREFIX}{exp_id}",
                GSI1_PK: f"{GSI1_DS_EXP_PREFIX}{exp_id}",
                GSI1_SK: dataset_id,
            }
        )
    # Update last_update_time and digest
    now_ms = int(time.time() * 1000)
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is not None:
        digest = self._compute_dataset_digest(meta["name"], now_ms)
        self._table.update_item(
            pk=pk,
            sk=SK_DATASET_META,
            updates={"last_update_time": now_ms, "digest": digest},
        )
    return self.get_dataset(dataset_id)

remove_dataset_from_experiments

remove_dataset_from_experiments(dataset_id, experiment_ids)

Remove an evaluation dataset's association from one or more experiments.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def remove_dataset_from_experiments(
    self, dataset_id: str, experiment_ids: list[str]
) -> EvaluationDataset:
    """Remove an evaluation dataset's association from one or more experiments."""
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    for exp_id in experiment_ids:
        self._table.delete_item(pk=pk, sk=f"{SK_DATASET_EXP_PREFIX}{exp_id}")
    # Update last_update_time and digest
    now_ms = int(time.time() * 1000)
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is not None:
        digest = self._compute_dataset_digest(meta["name"], now_ms)
        self._table.update_item(
            pk=pk,
            sk=SK_DATASET_META,
            updates={"last_update_time": now_ms, "digest": digest},
        )
    return self.get_dataset(dataset_id)
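
A short sketch of the linkage round-trip (store, dataset_id, and the experiment IDs are assumed to exist):

# Each link is an item under the dataset partition plus a GSI1 entry,
# written with put_item/delete_item, so both calls are idempotent.
ds = store.add_dataset_to_experiments(dataset_id, [exp_a, exp_b])
ds = store.remove_dataset_from_experiments(dataset_id, [exp_b])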

upsert_dataset_records

upsert_dataset_records(dataset_id, records)

Upsert records into an evaluation dataset.

Each record is keyed by its input_hash (the first 8 hex characters of the SHA-256 of the canonically sorted JSON inputs). Existing records with the same input_hash are updated; new ones are inserted with a fresh edrec_<ulid> ID.

Parameters:

    dataset_id (str): The ID of the dataset to update. Required.
    records (list[dict[str, Any]]): List of record dicts with keys: inputs, outputs, expectations, tags, source. Required.

Returns:

    dict[str, int]: Dictionary with inserted and updated counts.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def upsert_dataset_records(
    self,
    dataset_id: str,
    records: list[dict[str, Any]],
) -> dict[str, int]:
    """Upsert records into an evaluation dataset.

    Each record is keyed by its input_hash (the first 8 hex characters
    of the SHA-256 of the canonically sorted JSON inputs). Existing
    records with the same input_hash are updated; new ones are inserted
    with a fresh ``edrec_<ulid>`` ID.

    Args:
        dataset_id: The ID of the dataset to update.
        records: List of record dicts with keys: inputs, outputs,
            expectations, tags, source.

    Returns:
        Dictionary with ``inserted`` and ``updated`` counts.
    """
    import json as _json

    from boto3.dynamodb.conditions import Attr

    # Verify dataset exists
    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is None:
        raise MlflowException(
            f"Dataset '{dataset_id}' does not exist.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    inserted = 0
    updated = 0

    for record in records:
        inputs = record.get("inputs", {})
        outputs = record.get("outputs")
        expectations = record.get("expectations")
        record_tags = record.get("tags")
        source = record.get("source")

        # Compute input_hash
        input_hash = hashlib.sha256(
            _json.dumps(inputs, sort_keys=True, separators=(",", ":")).encode()
        ).hexdigest()[:8]

        # Dedup: query LSI3 for existing record with same input_hash
        existing_items, _ = self._table.query_page(
            pk=pk,
            sk_prefix=input_hash,
            index_name="lsi3",
            filter_expression=Attr("SK").begins_with(SK_DATASET_RECORD_PREFIX),
        )

        now_ms = int(time.time() * 1000)

        if existing_items:
            # Update the first matching record
            existing = existing_items[0]
            updates: dict[str, Any] = {
                "last_update_time": now_ms,
                LSI2_SK: now_ms,
            }
            if outputs is not None:
                updates["outputs"] = outputs
            if expectations is not None:
                updates["expectations"] = expectations
            if record_tags is not None:
                updates["tags"] = record_tags
            if source is not None:
                updates["source"] = source
            self._table.update_item(
                pk=existing["PK"],
                sk=existing["SK"],
                updates=updates,
            )
            updated += 1
        else:
            # Insert new record
            record_id = f"edrec_{generate_ulid()}"
            record_item: dict[str, Any] = {
                "PK": pk,
                "SK": f"{SK_DATASET_RECORD_PREFIX}{record_id}",
                "dataset_id": dataset_id,
                "dataset_record_id": record_id,
                "inputs": inputs,
                "input_hash": input_hash,
                "created_time": now_ms,
                "last_update_time": now_ms,
                # LSI projections
                LSI1_SK: f"{now_ms:020d}",
                LSI2_SK: now_ms,
                LSI3_SK: input_hash,
            }
            if outputs is not None:
                record_item["outputs"] = outputs
            if expectations is not None:
                record_item["expectations"] = expectations
            if record_tags is not None:
                record_item["tags"] = record_tags
            if source is not None:
                record_item["source"] = source
            self._table.put_item(record_item)
            inserted += 1

    # Recount records to update META profile
    all_records, _ = self._table.query_page(
        pk=pk,
        sk_prefix=SK_DATASET_RECORD_PREFIX,
        limit=10000,
    )
    num_records = len(all_records)

    now_ms = int(time.time() * 1000)
    digest = self._compute_dataset_digest(meta["name"], now_ms)
    self._table.update_item(
        pk=pk,
        sk=SK_DATASET_META,
        updates={
            "profile": _json.dumps({"num_records": num_records}),
            "last_update_time": now_ms,
            "digest": digest,
            LSI2_SK: now_ms,
        },
    )

    return {"inserted": inserted, "updated": updated}
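
To make the dedup key concrete, here is a sketch that mirrors the hash computed above (canonical JSON of the inputs, SHA-256, first 8 hex characters); the record contents and the store/dataset_id names are hypothetical:

import hashlib
import json

def input_hash(inputs: dict) -> str:
    # Same recipe as the store: sorted keys, compact separators, 8 hex chars.
    blob = json.dumps(inputs, sort_keys=True, separators=(",", ":")).encode()
    return hashlib.sha256(blob).hexdigest()[:8]

records = [
    {"inputs": {"question": "What is MLflow?"}, "expectations": {"answer": "an ML platform"}},
    {"inputs": {"question": "What is DynamoDB?"}},
]
counts = store.upsert_dataset_records(dataset_id, records)
# The first call inserts both; re-upserting the same inputs updates instead.
print(counts)  # {"inserted": 2, "updated": 0}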

delete_dataset_records

delete_dataset_records(dataset_id, record_ids)

Delete specific records from a dataset.

Parameters:

    dataset_id (str): The dataset to delete records from. Required.
    record_ids (list[str]): List of record IDs (edrec_...) to delete. Required.

Returns:

    int: Count of records deleted (equal to the number of record IDs submitted).

Source code in src/mlflow_dynamodbstore/tracking_store.py
def delete_dataset_records(
    self,
    dataset_id: str,
    record_ids: list[str],
) -> int:
    """Delete specific records from a dataset.

    Args:
        dataset_id: The dataset to delete records from.
        record_ids: List of record IDs (``edrec_...``) to delete.

    Returns:
        Count of records deleted (equal to the number of record IDs submitted).
    """
    import json as _json

    pk = f"{PK_DATASET_PREFIX}{dataset_id}"
    keys = [{"PK": pk, "SK": f"{SK_DATASET_RECORD_PREFIX}{rec_id}"} for rec_id in record_ids]
    self._table.batch_delete(keys)

    # Recount and update META profile
    meta = self._table.get_item(pk=pk, sk=SK_DATASET_META)
    if meta is not None:
        all_records, _ = self._table.query_page(
            pk=pk,
            sk_prefix=SK_DATASET_RECORD_PREFIX,
            limit=10000,
        )
        num_records = len(all_records)
        now_ms = int(time.time() * 1000)
        digest = self._compute_dataset_digest(meta["name"], now_ms)
        self._table.update_item(
            pk=pk,
            sk=SK_DATASET_META,
            updates={
                "profile": _json.dumps({"num_records": num_records}),
                "last_update_time": now_ms,
                "digest": digest,
                LSI2_SK: now_ms,
            },
        )

    return len(record_ids)

query_trace_metrics

query_trace_metrics(experiment_ids, view_type, metric_name, aggregations, dimensions=None, filters=None, time_interval_seconds=None, start_time_ms=None, end_time_ms=None, max_results=1000, page_token=None)

Query aggregated trace metrics across experiments.

Source code in src/mlflow_dynamodbstore/tracking_store.py
def query_trace_metrics(
    self,
    experiment_ids: list[str],
    view_type: MetricViewType,
    metric_name: str,
    aggregations: list[MetricAggregation],
    dimensions: list[str] | None = None,
    filters: list[str] | None = None,
    time_interval_seconds: int | None = None,
    start_time_ms: int | None = None,
    end_time_ms: int | None = None,
    max_results: int = 1000,
    page_token: str | None = None,
) -> PagedList[list[MetricDataPoint]]:
    """Query aggregated trace metrics across experiments."""
    from mlflow.entities.trace_metrics import AggregationType, MetricViewType
    from mlflow.store.tracking.utils.sql_trace_metrics_utils import (
        validate_query_trace_metrics_params,
    )

    from mlflow_dynamodbstore.trace_metrics.accumulators import MetricAccumulator
    from mlflow_dynamodbstore.trace_metrics.extractors import (
        TIME_BUCKET_LABEL,
        build_dimension_key,
        compute_time_bucket,
        extract_metric_value,
        get_timestamp_for_view,
    )
    from mlflow_dynamodbstore.trace_metrics.filters import (
        apply_trace_metric_filters,
        filter_assessment_items,
        filter_span_items,
        meta_prefilter_spans,
    )
    from mlflow_dynamodbstore.trace_metrics.pagination import (
        cache_get,
        cache_put,
        compute_query_hash,
        decode_page_token,
        encode_page_token,
    )

    # 1. VALIDATE
    validate_query_trace_metrics_params(view_type, metric_name, aggregations, dimensions)
    if time_interval_seconds is not None and (start_time_ms is None or end_time_ms is None):
        raise MlflowException(
            "start_time_ms and end_time_ms are required when time_interval_seconds is set.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # 2. CHECK CACHE (if page_token)
    if page_token:
        token_data = decode_page_token(page_token)
        cached = cache_get(self._table, token_data["query_hash"])
        if cached is not None:
            offset = token_data["offset"]
            page = cached[offset : offset + max_results]
            next_token = None
            if offset + max_results < len(cached):
                next_token = encode_page_token(token_data["query_hash"], offset + max_results)
            return PagedList(page, next_token)  # type: ignore[arg-type]

    # Need percentile values?
    need_values = any(a.aggregation_type == AggregationType.PERCENTILE for a in aggregations)

    # 3. STREAM TRACE META ITEMS and accumulate
    accumulators: dict[tuple[str | None, ...], MetricAccumulator] = {}
    dim_labels: dict[tuple[str | None, ...], dict[str, str]] = {}

    for experiment_id in experiment_ids:
        pk = f"{PK_EXPERIMENT_PREFIX}{experiment_id}"

        # Query META items (with optional time range via LSI1)
        if start_time_ms is not None and end_time_ms is not None:
            meta_candidates = self._table.query(
                pk=pk,
                index_name="lsi1",
                sk_gte=f"{start_time_ms:020d}",
                sk_lte=f"{end_time_ms:020d}",
            )
        else:
            # NOTE: Without time range, this fetches all T# items (including sub-items).
            # The "request_time" in item filter discards non-META items in Python.
            # This is known technical debt — a dedicated META-only
            # index would be more efficient.
            meta_candidates = self._table.query(pk=pk, sk_prefix=SK_TRACE_PREFIX)

        # Filter to META items only (have "request_time" attribute)
        meta_items = [item for item in meta_candidates if "request_time" in item]

        for meta_item in meta_items:
            trace_id = meta_item["SK"][len(SK_TRACE_PREFIX) :]

            # Only query tags if needed for filters or trace_name dimension
            needs_tags = False
            if filters:
                needs_tags = any("tag" in f for f in filters)
            if dimensions and "trace_name" in dimensions:
                needs_tags = True
            needs_metadata = bool(filters and any("metadata" in f for f in filters))

            tag_items = (
                self._table.query(pk=pk, sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}#TAG#")
                if needs_tags
                else []
            )
            metadata_items = (
                self._table.query(pk=pk, sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}#RMETA#")
                if needs_metadata
                else []
            )

            # Apply trace-level filters
            if not apply_trace_metric_filters(
                meta_item, filters, view_type, tag_items, metadata_items
            ):
                continue

            # Build trace_tags dict for dimension extraction
            trace_tags = {t["key"]: t["value"] for t in tag_items}

            if view_type == MetricViewType.TRACES:
                # Fetch TMETRIC items for token metrics
                tmetric_items = self._table.query(
                    pk=pk, sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}{SK_TRACE_METRIC_PREFIX}"
                )

                # Compute time bucket
                time_bucket = None
                if time_interval_seconds is not None:
                    ts = get_timestamp_for_view(view_type, meta_item, meta_item)
                    time_bucket = compute_time_bucket(ts, time_interval_seconds)

                value = extract_metric_value(
                    metric_name, view_type, meta_item, meta_item, tmetric_items
                )
                if value is None:
                    continue

                dim_key = build_dimension_key(
                    dimensions, view_type, meta_item, meta_item, trace_tags, time_bucket
                )
                # Skip if any dimension value is None
                if dimensions and any(v is None for v in dim_key):
                    continue

                if dim_key not in accumulators:
                    accumulators[dim_key] = MetricAccumulator(collect_values=need_values)
                    # Build dimension labels
                    labels: dict[str, str] = {}
                    idx = 0
                    if time_bucket is not None:
                        labels[TIME_BUCKET_LABEL] = time_bucket
                        idx = 1
                    for d in dimensions or []:
                        labels[d] = str(dim_key[idx]) if dim_key[idx] is not None else ""
                        idx += 1
                    dim_labels[dim_key] = labels
                accumulators[dim_key].add(value)

            elif view_type == MetricViewType.SPANS:
                # Pre-filter using META denormalized fields
                if not meta_prefilter_spans(meta_item, filters):
                    continue

                # Fetch span items and span metric items
                span_items = self._table.query(
                    pk=pk, sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}{SK_SPAN_PREFIX}"
                )
                smetric_items = self._table.query(
                    pk=pk, sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}{SK_SPAN_METRIC_PREFIX}"
                )

                # Apply span-level filters
                span_items = filter_span_items(span_items, filters)

                for span_item in span_items:
                    time_bucket = None
                    if time_interval_seconds is not None:
                        ts = get_timestamp_for_view(view_type, span_item, meta_item)
                        time_bucket = compute_time_bucket(ts, time_interval_seconds)

                    value = extract_metric_value(
                        metric_name,
                        view_type,
                        span_item,
                        meta_item,
                        span_metric_items=smetric_items,
                    )
                    if value is None:
                        continue

                    dim_key = build_dimension_key(
                        dimensions, view_type, span_item, meta_item, trace_tags, time_bucket
                    )
                    if dimensions and any(v is None for v in dim_key):
                        continue

                    if dim_key not in accumulators:
                        accumulators[dim_key] = MetricAccumulator(collect_values=need_values)
                        labels = {}
                        idx = 0
                        if time_bucket is not None:
                            labels[TIME_BUCKET_LABEL] = time_bucket
                            idx = 1
                        for d in dimensions or []:
                            labels[d] = str(dim_key[idx]) if dim_key[idx] is not None else ""
                            idx += 1
                        dim_labels[dim_key] = labels
                    accumulators[dim_key].add(value)

            elif view_type == MetricViewType.ASSESSMENTS:
                # Fetch assessment items
                assess_items = self._table.query(
                    pk=pk, sk_prefix=f"{SK_TRACE_PREFIX}{trace_id}#ASSESS#"
                )

                # Apply assessment-level filters
                assess_items = filter_assessment_items(assess_items, filters)

                for assess_item in assess_items:
                    time_bucket = None
                    if time_interval_seconds is not None:
                        ts = get_timestamp_for_view(view_type, assess_item, meta_item)
                        time_bucket = compute_time_bucket(ts, time_interval_seconds)

                    value = extract_metric_value(metric_name, view_type, assess_item, meta_item)
                    if value is None:
                        continue

                    dim_key = build_dimension_key(
                        dimensions, view_type, assess_item, meta_item, trace_tags, time_bucket
                    )
                    if dimensions and any(v is None for v in dim_key):
                        continue

                    if dim_key not in accumulators:
                        accumulators[dim_key] = MetricAccumulator(collect_values=need_values)
                        labels = {}
                        idx = 0
                        if time_bucket is not None:
                            labels[TIME_BUCKET_LABEL] = time_bucket
                            idx = 1
                        for d in dimensions or []:
                            labels[d] = str(dim_key[idx]) if dim_key[idx] is not None else ""
                            idx += 1
                        dim_labels[dim_key] = labels
                    accumulators[dim_key].add(value)

    # 4. FINALIZE — compute aggregation results, convert to MetricDataPoint list
    from mlflow.entities.trace_metrics import MetricDataPoint

    all_data_points: list[MetricDataPoint] = []
    for dim_key, acc in accumulators.items():
        agg_values = acc.finalize(aggregations)
        labels = dim_labels.get(dim_key, {})
        dp = MetricDataPoint(
            metric_name=metric_name,
            dimensions=labels or {},
            values=agg_values,
        )
        all_data_points.append(dp)

    # 5. CACHE AND PAGINATE
    query_hash = compute_query_hash(
        experiment_ids,
        view_type,
        metric_name,
        aggregations,
        dimensions,
        filters,
        time_interval_seconds,
        start_time_ms,
        end_time_ms,
    )
    offset = 0
    if page_token:
        from mlflow_dynamodbstore.trace_metrics.pagination import decode_page_token as _decode

        token_data = _decode(page_token)
        offset = token_data.get("offset", 0)

    cache_put(self._table, query_hash, all_data_points)
    page = all_data_points[offset : offset + max_results]
    next_offset = offset + max_results
    next_token = (
        encode_page_token(query_hash, next_offset)
        if next_offset < len(all_data_points)
        else None
    )
    return PagedList(page, next_token)  # type: ignore[arg-type]
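
A hedged usage sketch: the metric name, the AVG aggregation member, and the MetricAggregation keyword shown here are assumptions; consult mlflow.entities.trace_metrics for the exact constructor:

import time

from mlflow.entities.trace_metrics import (
    AggregationType,
    MetricAggregation,
    MetricViewType,
)

end_ms = int(time.time() * 1000)
start_ms = end_ms - 24 * 3600 * 1000  # last 24 hours

# Hypothetical metric and aggregation; hourly buckets require the time range.
points = store.query_trace_metrics(
    experiment_ids=["<experiment-id>"],
    view_type=MetricViewType.TRACES,
    metric_name="latency_ms",
    aggregations=[MetricAggregation(aggregation_type=AggregationType.AVG)],
    time_interval_seconds=3600,
    start_time_ms=start_ms,
    end_time_ms=end_ms,
)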

Registry Store

DynamoDBRegistryStore

DynamoDBRegistryStore(store_uri)

Bases: AbstractStore

MLflow model registry store backed by DynamoDB.

Source code in src/mlflow_dynamodbstore/registry_store.py
def __init__(self, store_uri: str) -> None:
    uri = parse_dynamodb_uri(store_uri)
    if uri.deploy:
        ensure_stack_exists(uri.table_name, uri.region, uri.endpoint_url)
    self._table = DynamoDBTable(uri.table_name, uri.region, uri.endpoint_url)
    self._cache = ResolutionCache(workspace=lambda: self._workspace)
    self._prefetch: dict[str, tuple[float, list[dict[str, Any]], dict[str, Any] | None]] = {}
    self._config = ConfigReader(self._table)
    self._config.reconcile()
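
Construction mirrors the tracking store. The URI below (table name plus a region query parameter) is an illustrative assumption about what parse_dynamodb_uri accepts:

from mlflow_dynamodbstore.registry_store import DynamoDBRegistryStore

registry = DynamoDBRegistryStore("dynamodb://mlflow-table?region=us-east-1")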

supports_workspaces property

supports_workspaces

DynamoDB registry store always supports workspaces.

create_registered_model

create_registered_model(name, tags=None, description=None, deployment_job_id=None)

Create a new registered model.

Source code in src/mlflow_dynamodbstore/registry_store.py
def create_registered_model(
    self,
    name: str,
    tags: list[RegisteredModelTag] | None = None,
    description: str | None = None,
    deployment_job_id: str | None = None,
) -> RegisteredModel:
    """Create a new registered model."""
    from mlflow.prompt.registry_utils import handle_resource_already_exist_error, has_prompt_tag
    from mlflow.utils.validation import _validate_model_name

    _validate_model_name(name)
    # Check uniqueness via GSI3
    existing = self._table.query(
        pk=f"{GSI3_MODEL_NAME_PREFIX}{self._workspace}#{name}",
        index_name="gsi3",
        limit=1,
    )
    if existing:
        existing_ulid = existing[0][GSI3_SK]
        existing_tags = self._get_model_tags(existing_ulid)
        handle_resource_already_exist_error(
            name, has_prompt_tag(existing_tags), has_prompt_tag(tags)
        )

    now_ms = get_current_time_millis()
    model_ulid = generate_ulid()

    item: dict[str, Any] = {
        "PK": f"{PK_MODEL_PREFIX}{model_ulid}",
        "SK": SK_MODEL_META,
        "name": name,
        "description": description or "",
        "creation_timestamp": now_ms,
        "last_updated_timestamp": now_ms,
        "workspace": self._workspace,
        "tags": {},
        # LSI attributes
        LSI2_SK: now_ms,
        LSI3_SK: name,
        LSI4_SK: _rev(name),
        # GSI2: list models by workspace
        GSI2_PK: f"{GSI2_MODELS_PREFIX}{self._workspace}",
        GSI2_SK: f"{now_ms:020d}#{name}",
        # GSI3: unique name lookup
        GSI3_PK: f"{GSI3_MODEL_NAME_PREFIX}{self._workspace}#{name}",
        GSI3_SK: model_ulid,
        # GSI5: all model names
        GSI5_PK: f"{GSI5_MODEL_NAMES_PREFIX}{self._workspace}",
        GSI5_SK: f"{name}#{model_ulid}",
    }

    self._table.put_item(item, condition="attribute_not_exists(PK)")

    # Write NAME_REV item for suffix ILIKE support (GSI5)
    name_rev_item = {
        "PK": f"{PK_MODEL_PREFIX}{model_ulid}",
        "SK": SK_MODEL_NAME_REV,
        GSI5_PK: f"{GSI5_MODEL_NAMES_PREFIX}{self._workspace}",
        GSI5_SK: f"REV#{_rev(name.lower())}#{model_ulid}",
        "name": name,
    }
    self._table.put_item(name_rev_item)

    # Write tags if provided
    model_tags: list[RegisteredModelTag] = []
    if tags:
        for tag in tags:
            self._write_model_tag(model_ulid, tag)
            model_tags.append(tag)

    # Write FTS items for the model name
    fts_items = fts_items_for_text(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        entity_type="M",
        entity_id=model_ulid,
        field=None,
        text=name,
        workspace=self._workspace,
    )
    self._table.batch_write(fts_items)

    self._cache.put("model_name", name, model_ulid)
    return _item_to_registered_model(item, model_tags)
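
For example (the model name, tag, and description are hypothetical):

from mlflow.entities.model_registry import RegisteredModelTag

rm = registry.create_registered_model(
    name="churn-model",
    tags=[RegisteredModelTag("team", "growth")],
    description="customer churn classifier",
)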

get_registered_model

get_registered_model(name)

Fetch a registered model by name.

Source code in src/mlflow_dynamodbstore/registry_store.py
def get_registered_model(self, name: str) -> RegisteredModel:
    """Fetch a registered model by name."""
    model_ulid = self._resolve_model_ulid(name)

    item = self._table.get_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=SK_MODEL_META,
    )
    if item is None:
        raise MlflowException(
            f"Registered Model with name={name} not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    tags = self._get_model_tags(model_ulid)
    aliases = self._aliases_for_registered_model(model_ulid)
    rm = _item_to_registered_model(item, tags, aliases)
    rm.latest_versions = self.get_latest_versions(name)
    return rm

rename_registered_model

rename_registered_model(name, new_name)

Rename a registered model.

Source code in src/mlflow_dynamodbstore/registry_store.py
def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
    """Rename a registered model."""
    from mlflow.utils.validation import _validate_model_renaming

    _validate_model_renaming(new_name)
    model_ulid = self._resolve_model_ulid(name)

    # Check new name uniqueness
    existing = self._table.query(
        pk=f"{GSI3_MODEL_NAME_PREFIX}{self._workspace}#{new_name}",
        index_name="gsi3",
        limit=1,
    )
    if existing:
        raise MlflowException(
            f"Registered Model (name={new_name}) already exists.",
            error_code=RESOURCE_ALREADY_EXISTS,
        )

    now_ms = get_current_time_millis()

    self._table.update_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=SK_MODEL_META,
        updates={
            "name": new_name,
            "last_updated_timestamp": now_ms,
            LSI2_SK: now_ms,
            LSI3_SK: new_name,
            LSI4_SK: _rev(new_name),
            GSI2_SK: f"{now_ms:020d}#{new_name}",
            GSI3_PK: f"{GSI3_MODEL_NAME_PREFIX}{self._workspace}#{new_name}",
            GSI5_SK: f"{new_name}#{model_ulid}",
        },
    )

    # Update NAME_REV item with new reversed name
    self._table.update_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=SK_MODEL_NAME_REV,
        updates={
            GSI5_SK: f"REV#{_rev(new_name.lower())}#{model_ulid}",
            "name": new_name,
        },
    )

    # Delete old FTS items by querying the reverse index for this entity
    pk = f"{PK_MODEL_PREFIX}{model_ulid}"
    old_rev_items = self._table.query(
        pk=pk,
        sk_prefix=f"{SK_FTS_REV_PREFIX}M#{model_ulid}",
    )
    for rev_item in old_rev_items:
        self._table.delete_item(pk=pk, sk=rev_item["SK"])

    # Also delete the forward FTS items
    old_fts_items = self._table.query(pk=pk, sk_prefix=SK_FTS_PREFIX)
    for fts_item in old_fts_items:
        self._table.delete_item(pk=pk, sk=fts_item["SK"])

    # Write new FTS items for the new name
    new_fts_items = fts_items_for_text(
        pk=pk,
        entity_type="M",
        entity_id=model_ulid,
        field=None,
        text=new_name,
        workspace=self._workspace,
    )
    self._table.batch_write(new_fts_items)

    # Update name on all model version items
    ver_items = self._table.query(pk=pk, sk_prefix=SK_VERSION_PREFIX)
    for vi in ver_items:
        if SK_VERSION_TAG_SUFFIX not in vi["SK"]:
            self._table.update_item(pk=pk, sk=vi["SK"], updates={"name": new_name})

    # Invalidate old name cache, cache new name
    self._cache.invalidate("model_name", name)
    self._cache.put("model_name", new_name, model_ulid)

    tags = self._get_model_tags(model_ulid)
    updated_item = self._table.get_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=SK_MODEL_META,
    )
    assert updated_item is not None
    return _item_to_registered_model(updated_item, tags)

update_registered_model

update_registered_model(name, description, deployment_job_id=None)

Update a registered model's description.

Source code in src/mlflow_dynamodbstore/registry_store.py
def update_registered_model(
    self,
    name: str,
    description: str,
    deployment_job_id: str | None = None,
) -> RegisteredModel:
    """Update a registered model's description."""
    model_ulid = self._resolve_model_ulid(name)
    now_ms = get_current_time_millis()

    updated_item = self._table.update_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=SK_MODEL_META,
        updates={
            "description": description,
            "last_updated_timestamp": now_ms,
            LSI2_SK: now_ms,
            GSI2_SK: f"{now_ms:020d}#{name}",
        },
    )

    tags = self._get_model_tags(model_ulid)
    assert updated_item is not None
    return _item_to_registered_model(updated_item, tags)

delete_registered_model

delete_registered_model(name)

Delete a registered model and all its items.

Source code in src/mlflow_dynamodbstore/registry_store.py
def delete_registered_model(self, name: str) -> None:
    """Delete a registered model and all its items."""
    model_ulid = self._resolve_model_ulid(name)
    pk = f"{PK_MODEL_PREFIX}{model_ulid}"

    # Query all items in the partition and delete each
    items = self._table.query(pk=pk)
    for item in items:
        self._table.delete_item(pk=pk, sk=item["SK"])

    self._cache.invalidate("model_name", name)

search_registered_models

search_registered_models(filter_string=None, max_results=None, order_by=None, page_token=None)

Search registered models with filter, ordering, and pagination.

Source code in src/mlflow_dynamodbstore/registry_store.py
def search_registered_models(
    self,
    filter_string: str | None = None,
    max_results: int | None = None,
    order_by: list[str] | None = None,
    page_token: str | None = None,
) -> list[RegisteredModel]:
    """Search registered models with filter, ordering, and pagination."""

    from mlflow.store.entities import PagedList
    from mlflow.store.model_registry import (
        SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
        SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD,
    )
    from mlflow.utils.search_utils import SearchModelUtils, SearchUtils

    from mlflow_dynamodbstore.dynamodb.search import (
        FilterPredicate,
        _compare,
    )

    if max_results is None:
        max_results = SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT
    if not isinstance(max_results, int) or max_results < 1:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be a positive "
            f"integer, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )
    if max_results > SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at most "
            f"{SEARCH_REGISTERED_MODEL_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    # Validate order_by clauses
    for clause in order_by or []:
        SearchUtils.parse_order_by_for_search_registered_models(clause)

    # Parse filters
    if filter_string:
        parsed = SearchModelUtils.parse_search_filter(filter_string)
        predicates = [
            FilterPredicate(
                field_type=p["type"],
                key=p["key"],
                op=p["comparator"],
                value=p.get("value"),
            )
            for p in parsed
        ]
    else:
        predicates = []

    name_pred = next(
        (p for p in predicates if p.field_type == "attribute" and p.key == "name"),
        None,
    )
    tag_preds = [p for p in predicates if p.field_type == "tag" and p.key != IS_PROMPT_TAG_KEY]

    # Post-filter for tags and prompts (used by non-paginated paths)
    def _post_filter(models: list[RegisteredModel]) -> list[RegisteredModel]:
        if tag_preds:
            models = self._filter_models_by_tags(models, tag_preds, _compare)
        if self._is_querying_prompt(predicates):
            models = [m for m in models if m._is_prompt()]
        else:
            models = [m for m in models if not m._is_prompt()]
        return models

    # Filter for paginated path (tags + prompts, applied per-item)
    def _paginated_filter(model: RegisteredModel) -> bool:
        if tag_preds:
            tag_dict: dict[str, str] = {}
            if hasattr(model, "_tags") and isinstance(model._tags, dict):
                tag_dict = model._tags
            elif isinstance(model.tags, dict):
                tag_dict = model.tags
            else:
                for t in model.tags:
                    tag_dict[t.key] = t.value
            for pred in tag_preds:
                if not _compare(tag_dict.get(pred.key), pred.op, pred.value):
                    return False
        if self._is_querying_prompt(predicates):
            if not model._is_prompt():
                return False
        elif model._is_prompt():
            return False
        return True

    if name_pred and name_pred.op == "=":
        # Exact name — no pagination needed
        models = _post_filter(self._search_models_by_name_exact(name_pred.value))
        return PagedList(models[:max_results], token=None)

    if name_pred and name_pred.op in ("LIKE", "ILIKE"):
        _prefix_body = name_pred.value[:-1]
        is_prefix = (
            name_pred.value.endswith("%")
            and "%" not in _prefix_body
            and "_" not in _prefix_body
        )
        if not is_prefix:
            # Non-prefix LIKE — FTS / general fallback, no pagination
            models = _post_filter(
                self._search_models_by_name_like(name_pred.value, name_pred.op)
            )
            models = self._sort_models(models, order_by or ["name ASC"])
            return PagedList(models[:max_results], token=None)
        # Prefix LIKE falls through to paginated path below

    # Paginated path: no name filter, or prefix LIKE
    name_prefix = name_pred.value[:-1] if name_pred else None
    index_name, scan_forward = self._resolve_registered_model_order(order_by)
    if index_name == "gsi5":
        pk = f"{GSI5_MODEL_NAMES_PREFIX}{self._workspace}"
        sk_prefix = name_prefix
    else:
        pk = f"{GSI2_MODELS_PREFIX}{self._workspace}"
        sk_prefix = None

    # When using GSI2 with a prefix LIKE, add name filtering
    if name_prefix and index_name != "gsi5":
        _inner = _paginated_filter

        def _name_and_paginated_filter(model: RegisteredModel) -> bool:
            if not model.name.startswith(name_prefix):
                return False
            return _inner(model)

        filter_fn = _name_and_paginated_filter
    else:
        filter_fn = _paginated_filter

    models, next_token = self._search_models_paginated(
        pk=pk,
        index_name=index_name,
        sk_prefix=sk_prefix,
        scan_forward=scan_forward,
        max_results=max_results,
        page_token=page_token,
        filter_fn=filter_fn,
        order_by=order_by,
    )
    return PagedList(models, next_token)
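
A usage sketch (the registry instance and model names are assumed). Note from the source that an exact name match short-circuits via GSI3, while a prefix LIKE takes the paginated path:

page = registry.search_registered_models(
    filter_string="name LIKE 'churn-%'",  # prefix LIKE: paginated GSI5/GSI2 path
    order_by=["name ASC"],
    max_results=50,
)
for model in page:
    print(model.name)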

set_registered_model_tag

set_registered_model_tag(name, tag)

Set a tag on a registered model.

Source code in src/mlflow_dynamodbstore/registry_store.py
def set_registered_model_tag(self, name: str, tag: RegisteredModelTag) -> None:
    """Set a tag on a registered model."""
    from mlflow.utils.validation import _validate_model_name, _validate_registered_model_tag

    _validate_model_name(name)
    _validate_registered_model_tag(tag.key, tag.value)
    model_ulid = self._resolve_model_ulid(name)
    self._write_model_tag(model_ulid, tag)

delete_registered_model_tag

delete_registered_model_tag(name, key)

Delete a tag from a registered model.

Source code in src/mlflow_dynamodbstore/registry_store.py
def delete_registered_model_tag(self, name: str, key: str) -> None:
    """Delete a tag from a registered model."""
    from mlflow.utils.validation import _validate_model_name, _validate_registered_model_tag

    _validate_model_name(name)
    _validate_registered_model_tag(key, "")
    model_ulid = self._resolve_model_ulid(name)
    self._table.delete_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_MODEL_TAG_PREFIX}{key}",
    )
    if self._config.should_denormalize(None, key):
        self._remove_denormalized_tag(
            pk=f"{PK_MODEL_PREFIX}{model_ulid}",
            sk=SK_MODEL_META,
            tag_key=key,
        )

create_model_version

create_model_version(name, source, run_id=None, tags=None, run_link=None, description=None, local_model_path=None, model_id=None)

Create a new model version under the given registered model.

Source code in src/mlflow_dynamodbstore/registry_store.py
def create_model_version(
    self,
    name: str,
    source: str,
    run_id: str | None = None,
    tags: list[Any] | None = None,
    run_link: str | None = None,
    description: str | None = None,
    local_model_path: str | None = None,
    model_id: str | None = None,
) -> ModelVersion:
    """Create a new model version under the given registered model."""
    # Resolve models:/<name>/<version> URI to actual storage location
    storage_location = source
    if source and source.startswith("models:/"):
        # Parse models:/<model_name>/<version>
        parts = source[len("models:/") :].split("/")
        if len(parts) >= 2:
            src_name, src_version = parts[0], parts[1]
            storage_location = self.get_model_version_download_uri(src_name, src_version)

    if not run_id and model_id:
        model = MlflowClient().get_logged_model(model_id)
        run_id = model.source_run_id

    model_ulid = self._resolve_model_ulid(name)
    pk = f"{PK_MODEL_PREFIX}{model_ulid}"

    # Atomically increment version counter on model META item
    counter = self._table.add_attribute(pk, SK_MODEL_META, "next_version", 1)
    next_ver = int(counter["next_version"])
    padded = _pad_version(next_ver)

    now_ms = get_current_time_millis()

    item: dict[str, Any] = {
        "PK": pk,
        "SK": f"{SK_VERSION_PREFIX}{padded}",
        "name": name,
        "version": padded,
        "source": source or "",
        "storage_location": storage_location or "",
        "run_id": run_id or "",
        "run_link": run_link or "",
        "description": description or "",
        "status": "READY",
        "current_stage": "None",
        "creation_timestamp": now_ms,
        "last_updated_timestamp": now_ms,
        "tags": {},
        # LSI attributes
        LSI1_SK: str(now_ms),
        LSI2_SK: now_ms,
        LSI3_SK: f"None#{padded}",
    }

    # Sparse LSI keys — DynamoDB rejects empty string index keys
    if source:
        item[LSI4_SK] = source.lower()
    if run_id:
        item[LSI5_SK] = f"{run_id}#{padded}"

    # GSI1: run linkage (only if run_id provided)
    if run_id:
        item[GSI1_PK] = f"{GSI1_RUN_PREFIX}{run_id}"
        item[GSI1_SK] = f"MV#{model_ulid}#{padded}"

    self._table.put_item(item)

    # Update model's last_updated_timestamp
    self._table.update_item(
        pk=pk,
        sk=SK_MODEL_META,
        updates={"last_updated_timestamp": now_ms, LSI2_SK: now_ms},
    )

    # Write tags if provided
    version_tags: list[ModelVersionTag] = []
    if tags:
        for tag in tags:
            self._write_version_tag(model_ulid, padded, tag)
            version_tags.append(tag)

    return _item_to_model_version(item, version_tags)
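
A sketch (the model name, source URI, and run ID are hypothetical):

mv = registry.create_model_version(
    name="churn-model",
    source="s3://my-bucket/artifacts/model",
    run_id="0123456789abcdef0123456789abcdef",  # optional; adds GSI1 run-linkage keys
    description="first candidate",
)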

get_model_version

get_model_version(name, version)

Fetch a model version by name and version number.

Source code in src/mlflow_dynamodbstore/registry_store.py
def get_model_version(self, name: str, version: str) -> ModelVersion:
    """Fetch a model version by name and version number."""
    try:
        model_ulid = self._resolve_model_ulid(name)
    except MlflowException as e:
        if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
            raise MlflowException(
                f"Model Version (name={name}, version={version}) not found",
                error_code=RESOURCE_DOES_NOT_EXIST,
            ) from None
        raise
    padded = _pad_version(version)

    item = self._table.get_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_VERSION_PREFIX}{padded}",
    )
    if item is None or item.get("current_stage") == STAGE_DELETED_INTERNAL:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    tags = self._get_version_tags(model_ulid, padded)
    aliases = self._aliases_for_model_version(model_ulid, int(version))
    return _item_to_model_version(item, tags, aliases)

update_model_version

update_model_version(name, version, description)

Update a model version's description.

Source code in src/mlflow_dynamodbstore/registry_store.py
def update_model_version(self, name: str, version: str, description: str) -> ModelVersion:
    """Update a model version's description."""
    # Verify version exists (raises on deleted versions)
    self.get_model_version(name, version)
    model_ulid = self._resolve_model_ulid(name)
    padded = _pad_version(version)
    now_ms = get_current_time_millis()

    updated_item = self._table.update_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_VERSION_PREFIX}{padded}",
        updates={
            "description": description,
            "last_updated_timestamp": now_ms,
            LSI2_SK: now_ms,
        },
    )

    if updated_item is None:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    tags = self._get_version_tags(model_ulid, padded)
    return _item_to_model_version(updated_item, tags)

delete_model_version

delete_model_version(name, version)

Soft-delete a model version: redact sensitive fields, mark as deleted.

Source code in src/mlflow_dynamodbstore/registry_store.py
def delete_model_version(self, name: str, version: str) -> None:
    """Soft-delete a model version: redact sensitive fields, mark as deleted."""
    # Verify version exists (raises on deleted/missing versions)
    self.get_model_version(name, version)
    model_ulid = self._resolve_model_ulid(name)
    padded = _pad_version(version)
    pk = f"{PK_MODEL_PREFIX}{model_ulid}"
    now_ms = get_current_time_millis()

    # Soft-delete: redact sensitive fields, set deleted stage
    updates: dict[str, Any] = {
        "current_stage": STAGE_DELETED_INTERNAL,
        "source": "REDACTED-SOURCE-PATH",
        "run_id": "REDACTED-RUN-ID",
        "run_link": "REDACTED-RUN-LINK",
        "description": "",
        "status_message": "",
        "last_updated_timestamp": now_ms,
        LSI2_SK: now_ms,
        LSI3_SK: f"{STAGE_DELETED_INTERNAL}#{padded}",
    }

    # Optional TTL for eventual hard-delete
    ttl_seconds = self._config.get_soft_deleted_ttl_seconds()
    if ttl_seconds is not None:
        updates["ttl"] = int(time.time()) + ttl_seconds

    # Remove sparse index keys so deleted version won't appear in filtered queries
    removes = [LSI4_SK, LSI5_SK, GSI1_PK, GSI1_SK]

    self._table.update_item(
        pk=pk,
        sk=f"{SK_VERSION_PREFIX}{padded}",
        updates=updates,
        removes=removes,
    )

    # Hard-delete tags (no value in keeping redacted version's tags)
    tag_prefix = f"{SK_VERSION_PREFIX}{padded}{SK_VERSION_TAG_SUFFIX}"
    tag_items = self._table.query(pk=pk, sk_prefix=tag_prefix)
    for tag_item in tag_items:
        self._table.delete_item(pk=pk, sk=tag_item["SK"])

    # Delete aliases pointing to this version
    for alias_name in self._aliases_for_model_version(model_ulid, int(version)):
        self._table.delete_item(pk=pk, sk=f"{SK_MODEL_ALIAS_PREFIX}{alias_name}")
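
Because deletion is a soft delete, a subsequent fetch fails rather than returning the redacted item (names here are hypothetical):

registry.delete_model_version("churn-model", "3")
# Raises MlflowException with RESOURCE_DOES_NOT_EXIST: get_model_version
# filters out versions whose current_stage is the internal deleted marker.
registry.get_model_version("churn-model", "3")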

search_model_versions

search_model_versions(filter_string=None, max_results=None, order_by=None, page_token=None)

Search model versions with filter support.

Source code in src/mlflow_dynamodbstore/registry_store.py
def search_model_versions(
    self,
    filter_string: str | None = None,
    max_results: int | None = None,
    order_by: list[str] | None = None,
    page_token: str | None = None,
) -> list[ModelVersion]:
    """Search model versions with filter support."""
    from mlflow.utils.search_utils import SearchModelVersionUtils

    from mlflow_dynamodbstore.dynamodb.search import (
        FilterPredicate,
        _compare,
    )

    # Validate order_by clauses (raises on invalid columns/syntax)
    for clause in order_by or []:
        SearchModelVersionUtils.parse_order_by_for_search_model_versions(clause)

    if filter_string:
        parsed = SearchModelVersionUtils.parse_search_filter(filter_string)
        predicates = [
            FilterPredicate(
                field_type=p["type"],
                key=p["key"],
                op=p["comparator"],
                value=p.get("value"),
            )
            for p in parsed
        ]
    else:
        predicates = []

    name_pred = next(
        (p for p in predicates if p.field_type == "attribute" and p.key == "name"),
        None,
    )
    run_id_pred = next(
        (p for p in predicates if p.field_type == "attribute" and p.key == "run_id"),
        None,
    )

    # Separate prompt tag from other tag predicates — prompt filtering is handled separately
    tag_preds = [p for p in predicates if p.field_type == "tag" and p.key != IS_PROMPT_TAG_KEY]

    from mlflow.store.entities import PagedList
    from mlflow.store.model_registry import (
        SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
        SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
    )

    if max_results is None:
        max_results = SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT
    if not isinstance(max_results, int) or max_results < 1:
        raise MlflowException(
            f"Invalid value for max_results. It must be a positive integer,"
            f" but got {max_results}",
            INVALID_PARAMETER_VALUE,
        )
    if max_results > SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD:
        raise MlflowException(
            "Invalid value for request parameter max_results. It must be at most "
            f"{SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD}, but got value {max_results}",
            INVALID_PARAMETER_VALUE,
        )

    # Resolve model name(s) from filter
    model_name: str | None = None
    model_names: list[str] | None = None
    if name_pred and name_pred.op == "=":
        model_name = name_pred.value
    elif name_pred and name_pred.op in ("LIKE", "ILIKE"):
        # For prefix LIKE, find all matching models via GSI5
        pattern = name_pred.value
        if pattern.endswith("%") and "%" not in pattern[:-1]:
            prefix = pattern[:-1]
            items = self._table.query(
                pk=f"{GSI5_MODEL_NAMES_PREFIX}{self._workspace}",
                sk_prefix=prefix,
                index_name="gsi5",
            )
            model_names = [
                item["name"]
                for item in items
                if item.get("name") and not item.get(GSI5_SK, "").startswith("REV#")
            ]
            # Optimize: single match → use single-model path
            if len(model_names) == 1:
                model_name = model_names[0]
                model_names = None

    run_id_filter = run_id_pred.value if run_id_pred and run_id_pred.op == "=" else None

    # Build DynamoDB-native filters from attribute predicates
    from boto3.dynamodb.conditions import Attr

    # SK constraints from version_number (used by main-table queries)
    version_pred = next(
        (p for p in predicates if p.field_type == "attribute" and p.key == "version_number"),
        None,
    )
    sk_prefix: str | None = SK_VERSION_PREFIX
    sk_gte: str | None = None
    sk_lte: str | None = None
    if version_pred:
        padded_val = _pad_version(version_pred.value)
        if version_pred.op == "=":
            sk_prefix = f"{SK_VERSION_PREFIX}{padded_val}"
        elif version_pred.op == "<=":
            sk_gte = SK_VERSION_PREFIX
            sk_lte = f"{SK_VERSION_PREFIX}{padded_val}"
            sk_prefix = None
        elif version_pred.op == ">=":
            sk_gte = f"{SK_VERSION_PREFIX}{padded_val}"
            sk_lte = f"{SK_VERSION_PREFIX}99999999"
            sk_prefix = None
        elif version_pred.op == "<":
            sk_gte = SK_VERSION_PREFIX
            sk_lte = f"{SK_VERSION_PREFIX}{_pad_version(int(version_pred.value) - 1)}"
            sk_prefix = None
        elif version_pred.op == ">":
            sk_gte = f"{SK_VERSION_PREFIX}{_pad_version(int(version_pred.value) + 1)}"
            sk_lte = f"{SK_VERSION_PREFIX}99999999"
            sk_prefix = None

    # FilterExpression for source_path, run_id, and version_number (used by LSI2 queries)
    filter_expr: Any = None
    for p in predicates:
        if p.field_type != "attribute" or p.key == "name":
            continue
        # run_id exact match handled by GSI1 path when no model_name
        if p.key == "run_id" and p.op == "=" and not model_name:
            continue
        dynamo_field = _MV_ATTRIBUTE_MAP.get(p.key, p.key)
        val = _pad_version(p.value) if p.key == "version_number" else p.value
        condition: Any = None
        if p.op == "=":
            condition = Attr(dynamo_field).eq(val)
        elif p.op == "!=":
            condition = Attr(dynamo_field).ne(val)
        elif p.op == "<=":
            condition = Attr(dynamo_field).lte(val)
        elif p.op == ">=":
            condition = Attr(dynamo_field).gte(val)
        elif p.op == "<":
            condition = Attr(dynamo_field).lt(val)
        elif p.op == ">":
            condition = Attr(dynamo_field).gt(val)
        elif p.op == "IN":
            condition = Attr(dynamo_field).is_in(val)
        elif p.op in ("LIKE", "ILIKE"):
            pattern = val if isinstance(val, str) else str(val)
            if p.op == "ILIKE":
                pattern = pattern.lower()
            if pattern.endswith("%") and "%" not in pattern[:-1]:
                condition = Attr(dynamo_field).begins_with(pattern[:-1])
        if condition is not None:
            filter_expr = filter_expr & condition if filter_expr else condition

    # Build filter_fn for tag and prompt filtering
    def version_filter_fn(mv: ModelVersion) -> bool:
        if tag_preds:
            tag_dict = mv.tags or {}
            for pred in tag_preds:
                actual = tag_dict.get(pred.key)
                if not _compare(actual, pred.op, pred.value):
                    return False
        if self._is_querying_prompt(predicates):
            if not self._version_is_prompt(mv):
                return False
        elif self._version_is_prompt(mv):
            return False
        return True

    if model_name:
        # Single model: use LSI2 for native timestamp DESC ordering + pagination
        # LSI2 can't use SK constraints, so use FilterExpression for all attributes
        try:
            model_ulid = self._resolve_model_ulid(model_name)
        except MlflowException:
            return PagedList([], token=None)
        versions, next_token = self._get_versions_for_model(
            model_ulid,
            model_name,
            max_results=max_results,
            page_token=page_token,
            filter_fn=version_filter_fn,
            filter_expression=filter_expr,
        )
        # Apply run_id post-filter if also specified
        if run_id_filter:
            versions = [v for v in versions if v.run_id == run_id_filter]
        return PagedList(versions, next_token)

    if model_names:
        # Multiple models from LIKE prefix: query each, no pagination
        multi_versions: list[ModelVersion] = []
        for mn in model_names:
            try:
                ulid = self._resolve_model_ulid(mn)
            except MlflowException:
                continue
            mvs, _ = self._get_versions_for_model(
                ulid,
                mn,
                filter_fn=version_filter_fn,
                filter_expression=filter_expr,
            )
            multi_versions.extend(mvs)
        if run_id_filter:
            multi_versions = [v for v in multi_versions if v.run_id == run_id_filter]
        if max_results is not None:
            multi_versions = multi_versions[:max_results]
        return PagedList(multi_versions, token=None)

    # Non-single-model paths: use SK constraints + FilterExpression on main table
    if run_id_filter:
        versions = self._search_versions_by_run_id(run_id_filter)
    else:
        versions = self._list_all_versions(
            sk_prefix=sk_prefix,
            sk_gte=sk_gte,
            sk_lte=sk_lte,
            filter_expression=filter_expr,
        )

    # Defensive run_id post-filter (the single-model name path above returns early)
    if run_id_filter and model_name:
        versions = [v for v in versions if v.run_id == run_id_filter]

    # Apply tag filters
    if tag_preds:
        versions = self._filter_versions_by_tags(versions, tag_preds, _compare)

    # Filter by prompt status
    if self._is_querying_prompt(predicates):
        versions = [v for v in versions if self._version_is_prompt(v)]
    else:
        versions = [v for v in versions if not self._version_is_prompt(v)]

    if max_results is not None:
        versions = versions[:max_results]

    return PagedList(versions, token=None)
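
In practice this search path is driven through the standard MLflow client rather than by calling the store directly. A minimal sketch, assuming MLflow is configured to use this registry store (the model name and filter string are illustrative):

# Illustrative usage; the client delegates to the store's
# search_model_versions implementation shown above.
from mlflow import MlflowClient

client = MlflowClient()
for mv in client.search_model_versions("name = 'my-model'"):
    print(mv.name, mv.version, mv.current_stage)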

get_latest_versions

get_latest_versions(name, stages=None)

Get the latest version for each requested stage.

Source code in src/mlflow_dynamodbstore/registry_store.py
def get_latest_versions(self, name: str, stages: list[str] | None = None) -> list[ModelVersion]:
    """Get the latest version for each requested stage."""
    model_ulid = self._resolve_model_ulid(name)
    pk = f"{PK_MODEL_PREFIX}{model_ulid}"

    if not stages:
        # Get all versions and determine unique stages
        ver_items = self._table.query(pk=pk, sk_prefix=SK_VERSION_PREFIX)
        ver_items = [
            vi
            for vi in ver_items
            if SK_VERSION_TAG_SUFFIX not in vi["SK"]
            and vi.get("current_stage") != STAGE_DELETED_INTERNAL
        ]
        # Group by stage, pick latest (highest version) per stage
        stage_latest: dict[str, dict[str, Any]] = {}
        for vi in ver_items:
            stage = vi.get("current_stage", "None")
            existing_item = stage_latest.get(stage)
            if existing_item is None or vi["version"] > existing_item["version"]:
                stage_latest[stage] = vi
        results = []
        for vi in stage_latest.values():
            padded = vi["SK"].replace(SK_VERSION_PREFIX, "")
            tags = self._get_version_tags(model_ulid, padded)
            results.append(_item_to_model_version(vi, tags))
        return results

    from mlflow.entities.model_registry.model_version_stages import (
        get_canonical_stage,
    )

    results = []
    for stage in stages:
        canonical = get_canonical_stage(stage)
        # Query LSI3 where lsi3sk begins_with "CanonicalStage#"
        stage_items = self._table.query(
            pk=pk,
            sk_prefix=f"{canonical}#",
            index_name="lsi3",
            scan_forward=False,
            limit=1,
        )
        if stage_items:
            vi = stage_items[0]
            if vi.get("current_stage") == STAGE_DELETED_INTERNAL:
                continue
            padded = vi["SK"].replace(SK_VERSION_PREFIX, "")
            tags = self._get_version_tags(model_ulid, padded)
            results.append(_item_to_model_version(vi, tags))
    return results
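
A minimal usage sketch through the client (model name and stages are illustrative; recent MLflow releases deprecate stage-based APIs in favor of aliases):

# Illustrative usage of the per-stage lookup above.
from mlflow import MlflowClient

client = MlflowClient()
for mv in client.get_latest_versions("my-model", stages=["Staging", "Production"]):
    print(mv.current_stage, mv.version)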

set_model_version_tag

set_model_version_tag(name, version, tag)

Set a tag on a model version.

Source code in src/mlflow_dynamodbstore/registry_store.py
def set_model_version_tag(self, name: str, version: str, tag: Any) -> None:
    """Set a tag on a model version."""
    from mlflow.utils.validation import (
        _validate_model_name,
        _validate_model_version,
        _validate_model_version_tag,
    )

    _validate_model_name(name)
    _validate_model_version(version)
    _validate_model_version_tag(tag.key, tag.value)
    # Verify version exists (raises on deleted versions)
    self.get_model_version(name, version)
    model_ulid = self._resolve_model_ulid(name)
    padded = _pad_version(version)
    self._write_version_tag(model_ulid, padded, tag)

delete_model_version_tag

delete_model_version_tag(name, version, key)

Delete a tag from a model version.

Source code in src/mlflow_dynamodbstore/registry_store.py
def delete_model_version_tag(self, name: str, version: str, key: str) -> None:
    """Delete a tag from a model version."""
    from mlflow.utils.validation import (
        _validate_model_name,
        _validate_model_version,
        _validate_model_version_tag,
    )

    _validate_model_name(name)
    _validate_model_version(version)
    _validate_model_version_tag(key, "")
    # Verify version exists (raises on deleted versions)
    self.get_model_version(name, version)
    model_ulid = self._resolve_model_ulid(name)
    padded = _pad_version(version)
    self._table.delete_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_VERSION_PREFIX}{padded}{SK_VERSION_TAG_SUFFIX}{key}",
    )
    if self._config.should_denormalize(None, key):
        self._remove_denormalized_tag(
            pk=f"{PK_MODEL_PREFIX}{model_ulid}",
            sk=f"{SK_VERSION_PREFIX}{padded}",
            tag_key=key,
        )
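
Both tag operations map one-to-one onto the MLflow client; a minimal sketch (model name, version, and tag values are illustrative):

# Illustrative usage of the two tag methods above.
from mlflow import MlflowClient

client = MlflowClient()
client.set_model_version_tag("my-model", "1", "validated", "true")
client.delete_model_version_tag("my-model", "1", "validated")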

get_model_version_download_uri

get_model_version_download_uri(name, version)

Return the download URI for a model version.

Returns storage_location if set (resolved from models:/ URI), otherwise falls back to source.

Source code in src/mlflow_dynamodbstore/registry_store.py
def get_model_version_download_uri(self, name: str, version: str) -> str:
    """Return the download URI for a model version.

    Returns ``storage_location`` if set (resolved from ``models:/`` URI),
    otherwise falls back to ``source``.
    """
    try:
        model_ulid = self._resolve_model_ulid(name)
    except MlflowException:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        ) from None
    padded = _pad_version(version)
    item = self._table.get_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_VERSION_PREFIX}{padded}",
    )
    if item is None or item.get("current_stage") == STAGE_DELETED_INTERNAL:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return item.get("storage_location") or item.get("source") or ""
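
A minimal sketch through the client (model name and version are illustrative):

# Illustrative usage; resolves to storage_location or source
# exactly as described above.
from mlflow import MlflowClient

client = MlflowClient()
uri = client.get_model_version_download_uri("my-model", "1")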

transition_model_version_stage

transition_model_version_stage(name, version, stage, archive_existing_versions)

Transition a model version to a new stage.

Source code in src/mlflow_dynamodbstore/registry_store.py
def transition_model_version_stage(
    self,
    name: str,
    version: str,
    stage: str,
    archive_existing_versions: bool,
) -> ModelVersion:
    """Transition a model version to a new stage."""
    from mlflow.entities.model_registry.model_version_stages import (
        DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
        get_canonical_stage,
    )

    canonical_stage = get_canonical_stage(stage)

    is_active = canonical_stage in DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS
    if archive_existing_versions and not is_active:
        raise MlflowException(
            f"Model version transition cannot archive existing model versions "
            f"because '{stage}' is not an Active stage. Valid stages are "
            f"{', '.join(DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS)}",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Verify version exists (raises on deleted versions)
    self.get_model_version(name, version)
    model_ulid = self._resolve_model_ulid(name)
    padded = _pad_version(version)
    now_ms = get_current_time_millis()
    pk = f"{PK_MODEL_PREFIX}{model_ulid}"

    # Archive other versions in active stages if requested
    if archive_existing_versions:
        version_items = self._table.query(pk=pk, sk_prefix=SK_VERSION_PREFIX)
        for vi in version_items:
            if SK_VERSION_TAG_SUFFIX in vi["SK"]:
                continue
            vi_padded = vi["SK"].replace(SK_VERSION_PREFIX, "")
            if vi_padded == padded:
                continue
            vi_stage = vi.get("current_stage", "None")
            if vi_stage == STAGE_DELETED_INTERNAL:
                continue
            if vi_stage == canonical_stage:
                self._table.update_item(
                    pk=pk,
                    sk=vi["SK"],
                    updates={
                        "current_stage": "Archived",
                        "last_updated_timestamp": now_ms,
                        LSI2_SK: now_ms,
                        LSI3_SK: f"Archived#{vi_padded}",
                    },
                )

    updated_item = self._table.update_item(
        pk=pk,
        sk=f"{SK_VERSION_PREFIX}{padded}",
        updates={
            "current_stage": canonical_stage,
            "last_updated_timestamp": now_ms,
            LSI2_SK: now_ms,
            LSI3_SK: f"{canonical_stage}#{padded}",
        },
    )

    if updated_item is None:
        raise MlflowException(
            f"Model Version (name={name}, version={version}) not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    tags = self._get_version_tags(model_ulid, padded)
    return _item_to_model_version(updated_item, tags)
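
A minimal sketch through the client; with archive_existing_versions=True, other versions already in the target stage are moved to Archived, as implemented above (names and versions are illustrative):

# Illustrative usage of the stage transition above.
from mlflow import MlflowClient

client = MlflowClient()
mv = client.transition_model_version_stage(
    name="my-model",
    version="2",
    stage="Production",
    archive_existing_versions=True,
)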

set_registered_model_alias

set_registered_model_alias(name, alias, version)

Set an alias pointing to a specific version of a registered model.

Source code in src/mlflow_dynamodbstore/registry_store.py
def set_registered_model_alias(self, name: str, alias: str, version: str) -> None:
    """Set an alias pointing to a specific version of a registered model."""
    model_ulid = self._resolve_model_ulid(name)
    # Verify the version exists by fetching it (raises if not found)
    self.get_model_version(name, version)
    item = {
        "PK": f"{PK_MODEL_PREFIX}{model_ulid}",
        "SK": f"{SK_MODEL_ALIAS_PREFIX}{alias}",
        "alias": alias,
        "version": version,
    }
    self._table.put_item(item)

delete_registered_model_alias

delete_registered_model_alias(name, alias)

Delete an alias from a registered model (no-op if alias does not exist).

Source code in src/mlflow_dynamodbstore/registry_store.py
def delete_registered_model_alias(self, name: str, alias: str) -> None:
    """Delete an alias from a registered model (no-op if alias does not exist)."""
    model_ulid = self._resolve_model_ulid(name)
    self._table.delete_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_MODEL_ALIAS_PREFIX}{alias}",
    )

get_model_version_by_alias

get_model_version_by_alias(name, alias)

Return the model version that the given alias resolves to.

Source code in src/mlflow_dynamodbstore/registry_store.py
def get_model_version_by_alias(self, name: str, alias: str) -> ModelVersion:
    """Return the model version that the given alias resolves to."""
    model_ulid = self._resolve_model_ulid(name)
    item = self._table.get_item(
        pk=f"{PK_MODEL_PREFIX}{model_ulid}",
        sk=f"{SK_MODEL_ALIAS_PREFIX}{alias}",
    )
    if item is None:
        raise MlflowException(
            f"Registered model alias {alias} not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    version: str = item["version"]
    return self.get_model_version(name, version)
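
The three alias methods compose into a simple round trip; a minimal sketch (model name, alias, and version are illustrative):

# Illustrative alias round trip against the methods above.
from mlflow import MlflowClient

client = MlflowClient()
client.set_registered_model_alias("my-model", "champion", "3")
mv = client.get_model_version_by_alias("my-model", "champion")  # resolves to version 3
client.delete_registered_model_alias("my-model", "champion")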

Workspace Store

DynamoDBWorkspaceStore

DynamoDBWorkspaceStore(workspace_uri=None, store_uri=None)

Bases: AbstractStore

Workspace provider backed by DynamoDB.

Implements the MLflow AbstractStore interface for workspaces. Workspace items are stored as:

PK = WORKSPACE#<name>
SK = META

They are indexed on GSI2 with:

gsi2pk = WORKSPACES
gsi2sk = <name>
Source code in src/mlflow_dynamodbstore/workspace_store.py
def __init__(self, workspace_uri: str | None = None, store_uri: str | None = None) -> None:
    resolved_uri = workspace_uri or store_uri
    if resolved_uri is None:
        raise TypeError("DynamoDBWorkspaceStore requires 'workspace_uri' argument")
    uri = parse_dynamodb_uri(resolved_uri)
    if uri.deploy:
        ensure_stack_exists(uri.table_name, uri.region, uri.endpoint_url)
    self._table = DynamoDBTable(uri.table_name, uri.region, uri.endpoint_url)
    # Ensure the default workspace exists on startup
    self._ensure_default_workspace()
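
For reference, a workspace item written under this schema looks roughly like the following (the name, description, and artifact root are example values; the non-key attributes follow _put_workspace's parameters):

# Approximate item shape per the key schema described above.
{
    "PK": "WORKSPACE#team-a",
    "SK": "META",
    "gsi2pk": "WORKSPACES",
    "gsi2sk": "team-a",
    "description": "Team A's workspace",
    "default_artifact_root": "s3://team-a-artifacts",
}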

create_workspace

create_workspace(workspace)

Create a new workspace.

Raises MlflowException if a workspace with the given name already exists.

Source code in src/mlflow_dynamodbstore/workspace_store.py
def create_workspace(self, workspace: Workspace) -> Workspace:
    """Create a new workspace.

    Raises:
        MlflowException: If a workspace with the given name already exists.
    """
    from botocore.exceptions import ClientError
    from mlflow.store.workspace.abstract_store import WorkspaceNameValidator

    WorkspaceNameValidator.validate(workspace.name)
    try:
        self._put_workspace(
            workspace.name,
            description=workspace.description or "",
            default_artifact_root=workspace.default_artifact_root or "",
            condition="attribute_not_exists(PK)",
        )
    except ClientError as e:
        if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
            raise MlflowException(
                f"Workspace '{workspace.name}' already exists.",
                error_code=RESOURCE_ALREADY_EXISTS,
            ) from None
        raise
    return workspace
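
A minimal creation sketch, assuming the Workspace entity can be constructed from the three attributes the store reads (name, description, default_artifact_root) and that store is a DynamoDBWorkspaceStore instance:

# Illustrative only; the Workspace import path and constructor
# arguments are assumptions based on the attributes accessed above.
from mlflow.entities import Workspace

ws = Workspace(name="team-a", description="Team A", default_artifact_root="")
store.create_workspace(ws)  # raises MlflowException if "team-a" exists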

get_workspace

get_workspace(workspace_name)

Return a Workspace entity.

Raises MlflowException if the workspace does not exist.

Source code in src/mlflow_dynamodbstore/workspace_store.py
def get_workspace(self, workspace_name: str) -> Workspace:
    """Return a Workspace entity.

    Raises:
        MlflowException: If the workspace does not exist.
    """
    item = self._table.get_item(
        pk=f"{PK_WORKSPACE_PREFIX}{workspace_name}",
        sk=SK_WORKSPACE_META,
    )
    if item is None:
        raise MlflowException(
            f"Workspace '{workspace_name}' not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return self._item_to_workspace(item)

list_workspaces

list_workspaces()

Return all workspaces as Workspace entities.

Source code in src/mlflow_dynamodbstore/workspace_store.py
def list_workspaces(self) -> list[Workspace]:
    """Return all workspaces as Workspace entities."""
    items = self._table.query(
        pk=GSI2_WORKSPACES,
        index_name="gsi2",
    )
    return [self._item_to_workspace(item) for item in items]

update_workspace

update_workspace(workspace)

Update mutable workspace attributes.

An empty string for default_artifact_root signals "clear this field".

Source code in src/mlflow_dynamodbstore/workspace_store.py
def update_workspace(self, workspace: Workspace) -> Workspace:
    """Update mutable workspace attributes.

    An empty string for ``default_artifact_root`` signals "clear this field".
    """
    updates: dict[str, Any] = {}
    if workspace.description is not None:
        updates["description"] = workspace.description
    if workspace.default_artifact_root is not None:
        # Empty string means "clear" — store empty, _item_to_workspace returns None
        updates["default_artifact_root"] = workspace.default_artifact_root

    if updates:
        self._table.update_item(
            pk=f"{PK_WORKSPACE_PREFIX}{workspace.name}",
            sk=SK_WORKSPACE_META,
            updates=updates,
        )
    return self.get_workspace(workspace.name)

delete_workspace

delete_workspace(workspace_name, mode=RESTRICT)

Delete a workspace.

Raises MlflowException if attempting to delete the built-in 'default' workspace.

Source code in src/mlflow_dynamodbstore/workspace_store.py
def delete_workspace(
    self,
    workspace_name: str,
    mode: WorkspaceDeletionMode = WorkspaceDeletionMode.RESTRICT,
) -> None:
    """Delete a workspace.

    Raises:
        MlflowException: If attempting to delete the built-in 'default' workspace.
    """
    if workspace_name == "default":
        raise MlflowException(
            f"Cannot delete the reserved '{workspace_name}' workspace",
            error_code=INVALID_STATE,
        )
    if mode == WorkspaceDeletionMode.RESTRICT:
        self._check_workspace_empty(workspace_name)
    self._table.delete_item(
        pk=f"{PK_WORKSPACE_PREFIX}{workspace_name}",
        sk=SK_WORKSPACE_META,
    )

get_default_workspace

get_default_workspace()

Return the default workspace.

Source code in src/mlflow_dynamodbstore/workspace_store.py
def get_default_workspace(self) -> Workspace:
    """Return the default workspace."""
    return self.get_workspace("default")

resolve_artifact_root

resolve_artifact_root(default_artifact_root, workspace_name)

Allow per-workspace artifact storage roots.

Returns the workspace's default_artifact_root if configured, otherwise falls back to the server's default_artifact_root.

Source code in src/mlflow_dynamodbstore/workspace_store.py
def resolve_artifact_root(
    self, default_artifact_root: str | None, workspace_name: str
) -> tuple[str | None, bool]:
    """Allow per-workspace artifact storage roots.

    Returns the workspace's ``default_artifact_root`` if configured, otherwise
    falls back to the server's ``default_artifact_root``.
    """
    try:
        ws = self.get_workspace(workspace_name)
        if ws.default_artifact_root:
            return ws.default_artifact_root, False
    except MlflowException:
        pass
    return default_artifact_root, True
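
Per the code above, the second element of the returned tuple is False when the workspace's own root was used and True when the caller's default applies; a sketch (values are illustrative, store is a DynamoDBWorkspaceStore):

# Illustrative: a configured workspace root wins over the server default.
root, used_default = store.resolve_artifact_root("s3://global-artifacts", "team-a")
if used_default:
    print("falling back to server default:", root)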

Auth Store

store

DynamoDB-backed MLflow auth store.

DynamoDBAuthStore

DynamoDBAuthStore(store_uri)

Auth store backed by DynamoDB.

User items are stored as:

PK = USER#<username>
SK = U#META

They are indexed on GSI2 with:

gsi2pk = AUTH_USERS
gsi2sk = <username>
Source code in src/mlflow_dynamodbstore/auth/store.py
def __init__(self, store_uri: str) -> None:
    uri = parse_dynamodb_uri(store_uri)
    if uri.deploy:
        ensure_stack_exists(uri.table_name, uri.region, uri.endpoint_url)
    self._table = DynamoDBTable(uri.table_name, uri.region, uri.endpoint_url)

init_db

init_db(*args, **kwargs)

No-op — table is created by the provisioner.

Source code in src/mlflow_dynamodbstore/auth/store.py
def init_db(self, *args: Any, **kwargs: Any) -> None:
    """No-op — table is created by the provisioner."""
    pass

create_user

create_user(username, password, is_admin=False)

Create a new user.

Raises MlflowException if a user with the given username already exists.

Source code in src/mlflow_dynamodbstore/auth/store.py
def create_user(self, username: str, password: str, is_admin: bool = False) -> User:
    """Create a new user.

    Raises MlflowException if a user with the given username already exists.
    """
    password_hash = generate_password_hash(password)
    item: dict[str, Any] = {
        "PK": f"{PK_USER_PREFIX}{username}",
        "SK": SK_USER_META,
        "password_hash": password_hash,
        "is_admin": is_admin,
        # GSI2: list all users
        GSI2_PK: GSI2_AUTH_USERS,
        GSI2_SK: username,
    }
    try:
        self._table.put_item(item, condition="attribute_not_exists(PK)")
    except Exception as exc:
        if "ConditionalCheckFailedException" in str(exc):
            raise MlflowException(  # type: ignore[no-untyped-call]
                f"User '{username}' already exists.",
                error_code=RESOURCE_ALREADY_EXISTS,
            ) from exc
        raise

    return User(  # type: ignore[no-untyped-call]
        id_=_user_id_from_username(username),
        username=username,
        password_hash=password_hash,
        is_admin=is_admin,
        experiment_permissions=[],
        registered_model_permissions=[],
    )

authenticate_user

authenticate_user(username, password)

Verify a user's password. Returns True if valid, False otherwise.

Source code in src/mlflow_dynamodbstore/auth/store.py
def authenticate_user(self, username: str, password: str) -> bool:
    """Verify a user's password. Returns True if valid, False otherwise."""
    item = self._table.get_item(
        pk=f"{PK_USER_PREFIX}{username}",
        sk=SK_USER_META,
        consistent=True,
    )
    if item is None:
        return False
    return check_password_hash(item.get("password_hash", ""), password)
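
A minimal sketch of direct store usage covering creation and authentication (the URI is an example value in whatever scheme parse_dynamodb_uri accepts):

# Illustrative only; all values are examples.
from mlflow_dynamodbstore.auth.store import DynamoDBAuthStore

store = DynamoDBAuthStore("dynamodb://mlflow-table")
store.create_user("alice", "s3cret", is_admin=False)
assert store.authenticate_user("alice", "s3cret")
assert not store.authenticate_user("alice", "wrong-password")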

has_user

has_user(username)

Return True if the user exists.

Source code in src/mlflow_dynamodbstore/auth/store.py
def has_user(self, username: str) -> bool:
    """Return True if the user exists."""
    item = self._table.get_item(
        pk=f"{PK_USER_PREFIX}{username}",
        sk=SK_USER_META,
    )
    return item is not None

get_user

get_user(username)

Return a User entity.

Raises MlflowException if the user does not exist.

Source code in src/mlflow_dynamodbstore/auth/store.py
def get_user(self, username: str) -> User:
    """Return a User entity.

    Raises MlflowException if the user does not exist.
    """
    item = self._table.get_item(
        pk=f"{PK_USER_PREFIX}{username}",
        sk=SK_USER_META,
        consistent=True,
    )
    if item is None:
        raise MlflowException(  # type: ignore[no-untyped-call]
            f"User '{username}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return _item_to_user(item)

list_users

list_users()

Return all users.

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_users(self) -> list[User]:
    """Return all users."""
    items = self._table.query(
        pk=GSI2_AUTH_USERS,
        index_name=GSI2_NAME,
    )
    return [_item_to_user(item) for item in items]

update_user

update_user(username, password=None, is_admin=None)

Update a user's password and/or admin status.

Source code in src/mlflow_dynamodbstore/auth/store.py
def update_user(
    self,
    username: str,
    password: str | None = None,
    is_admin: bool | None = None,
) -> None:
    """Update a user's password and/or admin status."""
    updates: dict[str, Any] = {}
    if password is not None:
        updates["password_hash"] = generate_password_hash(password)
    if is_admin is not None:
        updates["is_admin"] = is_admin

    if updates:
        self._table.update_item(
            pk=f"{PK_USER_PREFIX}{username}",
            sk=SK_USER_META,
            updates=updates,
        )

delete_user

delete_user(username)

Delete a user and all associated items in the USER# partition.

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_user(self, username: str) -> None:
    """Delete a user and all associated items in the USER# partition."""
    pk = f"{PK_USER_PREFIX}{username}"
    # Query all items in this user's partition and delete them
    items = self._table.query(pk=pk)
    for item in items:
        self._table.delete_item(pk=pk, sk=item["SK"])

create_experiment_permission

create_experiment_permission(experiment_id, username, permission)

Create an experiment permission for a user.

Raises MlflowException if the permission already exists.

Source code in src/mlflow_dynamodbstore/auth/store.py
def create_experiment_permission(
    self, experiment_id: str, username: str, permission: str
) -> ExperimentPermission:
    """Create an experiment permission for a user.

    Raises MlflowException if the permission already exists.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}EXP#{experiment_id}"
    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "permission": permission,
        # GSI4: look up all users with permissions on an experiment
        GSI4_PK: f"{GSI4_PERM_PREFIX}EXP#{experiment_id}",
        GSI4_SK: f"{PK_USER_PREFIX}{username}",
    }
    try:
        self._table.put_item(item, condition="attribute_not_exists(PK)")
    except Exception as exc:
        if "ConditionalCheckFailedException" in str(exc):
            raise MlflowException(  # type: ignore[no-untyped-call]
                f"Permission for experiment '{experiment_id}' "
                f"and user '{username}' already exists.",
                error_code=RESOURCE_ALREADY_EXISTS,
            ) from exc
        raise

    return ExperimentPermission(  # type: ignore[no-untyped-call]
        experiment_id=experiment_id,
        user_id=_user_id_from_username(username),
        permission=permission,
    )

get_experiment_permission

get_experiment_permission(experiment_id, username)

Return an experiment permission.

Raises MlflowException if the permission does not exist.

Source code in src/mlflow_dynamodbstore/auth/store.py
def get_experiment_permission(self, experiment_id: str, username: str) -> ExperimentPermission:
    """Return an experiment permission.

    Raises MlflowException if the permission does not exist.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}EXP#{experiment_id}"
    item = self._table.get_item(pk=pk, sk=sk, consistent=True)
    if item is None:
        raise MlflowException(  # type: ignore[no-untyped-call]
            f"Permission for experiment '{experiment_id}' and user '{username}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return ExperimentPermission(  # type: ignore[no-untyped-call]
        experiment_id=experiment_id,
        user_id=_user_id_from_username(username),
        permission=item["permission"],
    )

list_experiment_permissions

list_experiment_permissions(username)

Return all experiment permissions for a user.

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_experiment_permissions(self, username: str) -> list[ExperimentPermission]:
    """Return all experiment permissions for a user."""
    pk = f"{PK_USER_PREFIX}{username}"
    items = self._table.query(pk=pk, sk_prefix=f"{SK_USER_PERM_PREFIX}EXP#")
    user_id = _user_id_from_username(username)
    return [
        ExperimentPermission(  # type: ignore[no-untyped-call]
            experiment_id=item["SK"].removeprefix(f"{SK_USER_PERM_PREFIX}EXP#"),
            user_id=user_id,
            permission=item["permission"],
        )
        for item in items
    ]

update_experiment_permission

update_experiment_permission(experiment_id, username, permission)

Update an experiment permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def update_experiment_permission(
    self, experiment_id: str, username: str, permission: str
) -> None:
    """Update an experiment permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}EXP#{experiment_id}"
    self._table.update_item(pk=pk, sk=sk, updates={"permission": permission})

delete_experiment_permission

delete_experiment_permission(experiment_id, username)

Delete an experiment permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_experiment_permission(self, experiment_id: str, username: str) -> None:
    """Delete an experiment permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}EXP#{experiment_id}"
    self._table.delete_item(pk=pk, sk=sk)

create_registered_model_permission

create_registered_model_permission(name, username, permission)

Create a registered model permission for a user.

Raises MlflowException if the permission already exists.

Source code in src/mlflow_dynamodbstore/auth/store.py
def create_registered_model_permission(
    self, name: str, username: str, permission: str
) -> RegisteredModelPermission:
    """Create a registered model permission for a user.

    Raises MlflowException if the permission already exists.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}MODEL#default#{name}"
    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "permission": permission,
        GSI4_PK: f"{GSI4_PERM_PREFIX}MODEL#default#{name}",
        GSI4_SK: f"{PK_USER_PREFIX}{username}",
    }
    try:
        self._table.put_item(item, condition="attribute_not_exists(PK)")
    except Exception as exc:
        if "ConditionalCheckFailedException" in str(exc):
            raise MlflowException(  # type: ignore[no-untyped-call]
                f"Permission for model '{name}' and user '{username}' already exists.",
                error_code=RESOURCE_ALREADY_EXISTS,
            ) from exc
        raise

    return RegisteredModelPermission(  # type: ignore[no-untyped-call]
        name=name,
        user_id=_user_id_from_username(username),
        permission=permission,
    )

get_registered_model_permission

get_registered_model_permission(name, username)

Return a registered model permission.

Raises MlflowException if the permission does not exist.

Source code in src/mlflow_dynamodbstore/auth/store.py
def get_registered_model_permission(
    self, name: str, username: str
) -> RegisteredModelPermission:
    """Return a registered model permission.

    Raises MlflowException if the permission does not exist.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}MODEL#default#{name}"
    item = self._table.get_item(pk=pk, sk=sk, consistent=True)
    if item is None:
        raise MlflowException(  # type: ignore[no-untyped-call]
            f"Permission for model '{name}' and user '{username}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return RegisteredModelPermission(  # type: ignore[no-untyped-call]
        name=name,
        user_id=_user_id_from_username(username),
        permission=item["permission"],
    )

list_registered_model_permissions

list_registered_model_permissions(username)

Return all registered model permissions for a user.

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_registered_model_permissions(self, username: str) -> list[RegisteredModelPermission]:
    """Return all registered model permissions for a user."""
    pk = f"{PK_USER_PREFIX}{username}"
    items = self._table.query(pk=pk, sk_prefix=f"{SK_USER_PERM_PREFIX}MODEL#default#")
    user_id = _user_id_from_username(username)
    return [
        RegisteredModelPermission(  # type: ignore[no-untyped-call]
            name=item["SK"].removeprefix(f"{SK_USER_PERM_PREFIX}MODEL#default#"),
            user_id=user_id,
            permission=item["permission"],
        )
        for item in items
    ]

update_registered_model_permission

update_registered_model_permission(name, username, permission)

Update a registered model permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def update_registered_model_permission(self, name: str, username: str, permission: str) -> None:
    """Update a registered model permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}MODEL#default#{name}"
    self._table.update_item(pk=pk, sk=sk, updates={"permission": permission})

delete_registered_model_permission

delete_registered_model_permission(name, username)

Delete a registered model permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_registered_model_permission(self, name: str, username: str) -> None:
    """Delete a registered model permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}MODEL#default#{name}"
    self._table.delete_item(pk=pk, sk=sk)

delete_registered_model_permissions

delete_registered_model_permissions(name)

Bulk-delete all permissions for a registered model (all users).

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_registered_model_permissions(self, name: str) -> None:
    """Bulk-delete all permissions for a registered model (all users)."""
    gsi4pk = f"{GSI4_PERM_PREFIX}MODEL#default#{name}"
    items = self._table.query(pk=gsi4pk, index_name=GSI4_NAME)
    for item in items:
        self._table.delete_item(pk=item["PK"], sk=item["SK"])

rename_registered_model_permissions

rename_registered_model_permissions(old_name, new_name)

Rename all permissions from old_name to new_name.

Source code in src/mlflow_dynamodbstore/auth/store.py
def rename_registered_model_permissions(self, old_name: str, new_name: str) -> None:
    """Rename all permissions from old_name to new_name."""
    gsi4pk = f"{GSI4_PERM_PREFIX}MODEL#default#{old_name}"
    items = self._table.query(pk=gsi4pk, index_name=GSI4_NAME)
    for item in items:
        # Delete old
        self._table.delete_item(pk=item["PK"], sk=item["SK"])
        # Write new
        username = item["PK"].removeprefix(PK_USER_PREFIX)
        new_sk = f"{SK_USER_PERM_PREFIX}MODEL#default#{new_name}"
        new_item: dict[str, Any] = {
            "PK": item["PK"],
            "SK": new_sk,
            "permission": item["permission"],
            GSI4_PK: f"{GSI4_PERM_PREFIX}MODEL#default#{new_name}",
            GSI4_SK: f"{PK_USER_PREFIX}{username}",
        }
        self._table.put_item(new_item)

set_workspace_permission

set_workspace_permission(workspace, username, permission)

Create or update a workspace permission (upsert).

Source code in src/mlflow_dynamodbstore/auth/store.py
def set_workspace_permission(self, workspace: str, username: str, permission: str) -> None:
    """Create or update a workspace permission (upsert)."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}WORKSPACE#{workspace}"
    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "permission": permission,
        GSI4_PK: f"{GSI4_PERM_PREFIX}WORKSPACE#{workspace}",
        GSI4_SK: f"{PK_USER_PREFIX}{username}",
    }
    self._table.put_item(item)  # No condition — upsert

get_workspace_permission

get_workspace_permission(workspace, username)

Return a workspace permission.

Raises MlflowException if the permission does not exist.

Source code in src/mlflow_dynamodbstore/auth/store.py
def get_workspace_permission(self, workspace: str, username: str) -> WorkspacePermission:
    """Return a workspace permission.

    Raises MlflowException if the permission does not exist.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}WORKSPACE#{workspace}"
    item = self._table.get_item(pk=pk, sk=sk, consistent=True)
    if item is None:
        raise MlflowException(  # type: ignore[no-untyped-call]
            f"Permission for workspace '{workspace}' and user '{username}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return WorkspacePermission(  # type: ignore[no-untyped-call]
        workspace=workspace,
        user_id=_user_id_from_username(username),
        permission=item["permission"],
    )

list_workspace_permissions

list_workspace_permissions(workspace)

Return all permissions for a workspace (all users).

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_workspace_permissions(self, workspace: str) -> list[WorkspacePermission]:
    """Return all permissions for a workspace (all users)."""
    gsi4pk = f"{GSI4_PERM_PREFIX}WORKSPACE#{workspace}"
    items = self._table.query(pk=gsi4pk, index_name=GSI4_NAME)
    return [
        WorkspacePermission(  # type: ignore[no-untyped-call]
            workspace=workspace,
            user_id=_user_id_from_username(item[GSI4_SK].removeprefix(PK_USER_PREFIX)),
            permission=item["permission"],
        )
        for item in items
    ]

list_user_workspace_permissions

list_user_workspace_permissions(username)

Return all workspace permissions for a user.

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_user_workspace_permissions(self, username: str) -> list[WorkspacePermission]:
    """Return all workspace permissions for a user."""
    pk = f"{PK_USER_PREFIX}{username}"
    items = self._table.query(pk=pk, sk_prefix=f"{SK_USER_PERM_PREFIX}WORKSPACE#")
    user_id = _user_id_from_username(username)
    return [
        WorkspacePermission(  # type: ignore[no-untyped-call]
            workspace=item["SK"].removeprefix(f"{SK_USER_PERM_PREFIX}WORKSPACE#"),
            user_id=user_id,
            permission=item["permission"],
        )
        for item in items
    ]

delete_workspace_permission

delete_workspace_permission(workspace, username)

Delete a workspace permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_workspace_permission(self, workspace: str, username: str) -> None:
    """Delete a workspace permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}WORKSPACE#{workspace}"
    self._table.delete_item(pk=pk, sk=sk)

delete_workspace_permissions_for_workspace

delete_workspace_permissions_for_workspace(workspace)

Bulk-delete all permissions for a workspace (all users).

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_workspace_permissions_for_workspace(self, workspace: str) -> None:
    """Bulk-delete all permissions for a workspace (all users)."""
    gsi4pk = f"{GSI4_PERM_PREFIX}WORKSPACE#{workspace}"
    items = self._table.query(pk=gsi4pk, index_name=GSI4_NAME)
    for item in items:
        self._table.delete_item(pk=item["PK"], sk=item["SK"])

list_accessible_workspace_names

list_accessible_workspace_names(username)

Return workspace names that the user has any permission on.

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_accessible_workspace_names(self, username: str) -> list[str]:
    """Return workspace names that the user has any permission on."""
    pk = f"{PK_USER_PREFIX}{username}"
    items = self._table.query(pk=pk, sk_prefix=f"{SK_USER_PERM_PREFIX}WORKSPACE#")
    return [item["SK"].removeprefix(f"{SK_USER_PERM_PREFIX}WORKSPACE#") for item in items]

create_scorer_permission

create_scorer_permission(experiment_id, scorer_name, username, permission)

Create a scorer permission for a user.

Raises MlflowException if the permission already exists.

Source code in src/mlflow_dynamodbstore/auth/store.py
def create_scorer_permission(
    self,
    experiment_id: str,
    scorer_name: str,
    username: str,
    permission: str,
) -> ScorerPermission:
    """Create a scorer permission for a user.

    Raises MlflowException if the permission already exists.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}SCORER#{experiment_id}#{scorer_name}"
    item: dict[str, Any] = {
        "PK": pk,
        "SK": sk,
        "permission": permission,
        GSI4_PK: f"{GSI4_PERM_PREFIX}SCORER#{experiment_id}#{scorer_name}",
        GSI4_SK: f"{PK_USER_PREFIX}{username}",
    }
    try:
        self._table.put_item(item, condition="attribute_not_exists(PK)")
    except Exception as exc:
        if "ConditionalCheckFailedException" in str(exc):
            raise MlflowException(  # type: ignore[no-untyped-call]
                f"Permission for scorer '{scorer_name}' in experiment "
                f"'{experiment_id}' and user '{username}' already exists.",
                error_code=RESOURCE_ALREADY_EXISTS,
            ) from exc
        raise

    return ScorerPermission(  # type: ignore[no-untyped-call]
        experiment_id=experiment_id,
        scorer_name=scorer_name,
        user_id=_user_id_from_username(username),
        permission=permission,
    )

get_scorer_permission

get_scorer_permission(experiment_id, scorer_name, username)

Return a scorer permission.

Raises MlflowException if the permission does not exist.

Source code in src/mlflow_dynamodbstore/auth/store.py
def get_scorer_permission(
    self, experiment_id: str, scorer_name: str, username: str
) -> ScorerPermission:
    """Return a scorer permission.

    Raises MlflowException if the permission does not exist.
    """
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}SCORER#{experiment_id}#{scorer_name}"
    item = self._table.get_item(pk=pk, sk=sk, consistent=True)
    if item is None:
        raise MlflowException(  # type: ignore[no-untyped-call]
            f"Permission for scorer '{scorer_name}' in experiment "
            f"'{experiment_id}' and user '{username}' not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    return ScorerPermission(  # type: ignore[no-untyped-call]
        experiment_id=experiment_id,
        scorer_name=scorer_name,
        user_id=_user_id_from_username(username),
        permission=item["permission"],
    )

list_scorer_permissions

list_scorer_permissions(username)

Return all scorer permissions for a user.

Source code in src/mlflow_dynamodbstore/auth/store.py
def list_scorer_permissions(self, username: str) -> list[ScorerPermission]:
    """Return all scorer permissions for a user."""
    pk = f"{PK_USER_PREFIX}{username}"
    items = self._table.query(pk=pk, sk_prefix=f"{SK_USER_PERM_PREFIX}SCORER#")
    user_id = _user_id_from_username(username)
    result = []
    for item in items:
        suffix = item["SK"].removeprefix(f"{SK_USER_PERM_PREFIX}SCORER#")
        # suffix is "<experiment_id>#<scorer_name>"
        experiment_id, scorer_name = suffix.split("#", 1)
        result.append(
            ScorerPermission(  # type: ignore[no-untyped-call]
                experiment_id=experiment_id,
                scorer_name=scorer_name,
                user_id=user_id,
                permission=item["permission"],
            )
        )
    return result

update_scorer_permission

update_scorer_permission(experiment_id, scorer_name, username, permission)

Update a scorer permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def update_scorer_permission(
    self,
    experiment_id: str,
    scorer_name: str,
    username: str,
    permission: str,
) -> None:
    """Update a scorer permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}SCORER#{experiment_id}#{scorer_name}"
    self._table.update_item(pk=pk, sk=sk, updates={"permission": permission})

delete_scorer_permission

delete_scorer_permission(experiment_id, scorer_name, username)

Delete a scorer permission.

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_scorer_permission(self, experiment_id: str, scorer_name: str, username: str) -> None:
    """Delete a scorer permission."""
    pk = f"{PK_USER_PREFIX}{username}"
    sk = f"{SK_USER_PERM_PREFIX}SCORER#{experiment_id}#{scorer_name}"
    self._table.delete_item(pk=pk, sk=sk)

delete_scorer_permissions_for_scorer

delete_scorer_permissions_for_scorer(experiment_id, scorer_name)

Bulk-delete all permissions for a scorer (all users).

Source code in src/mlflow_dynamodbstore/auth/store.py
def delete_scorer_permissions_for_scorer(self, experiment_id: str, scorer_name: str) -> None:
    """Bulk-delete all permissions for a scorer (all users)."""
    gsi4pk = f"{GSI4_PERM_PREFIX}SCORER#{experiment_id}#{scorer_name}"
    items = self._table.query(pk=gsi4pk, index_name=GSI4_NAME)
    for item in items:
        self._table.delete_item(pk=item["PK"], sk=item["SK"])

X-Ray Client

client

XRayClient

XRayClient(region=None, endpoint_url=None)

Wrapper around boto3 X-Ray client with automatic chunking and pagination.

Source code in src/mlflow_dynamodbstore/xray/client.py
def __init__(self, region: str | None = None, endpoint_url: str | None = None):
    kwargs: dict[str, Any] = {}
    if region:
        kwargs["region_name"] = region
    if endpoint_url:
        kwargs["endpoint_url"] = endpoint_url
    self._client = boto3.client("xray", **kwargs)

get_trace_summaries

get_trace_summaries(start_time, end_time, filter_expression=None)

Get trace summaries with automatic time window chunking and pagination.

X-Ray limits GetTraceSummaries to 6-hour windows, so longer ranges are automatically split into consecutive chunks.

Source code in src/mlflow_dynamodbstore/xray/client.py
def get_trace_summaries(
    self,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
    filter_expression: str | None = None,
) -> list[dict[str, Any]]:
    """Get trace summaries with automatic time window chunking and pagination.

    X-Ray limits GetTraceSummaries to 6-hour windows, so longer ranges are
    automatically split into consecutive chunks.
    """
    summaries: list[dict[str, Any]] = []
    chunk_start = start_time
    while chunk_start < end_time:
        chunk_end = min(chunk_start + datetime.timedelta(hours=_MAX_WINDOW_HOURS), end_time)
        kwargs: dict[str, Any] = {"StartTime": chunk_start, "EndTime": chunk_end}
        if filter_expression:
            kwargs["FilterExpression"] = filter_expression
        # Paginate within chunk
        while True:
            response = self._client.get_trace_summaries(**kwargs)
            summaries.extend(response.get("TraceSummaries", []))
            next_token = response.get("NextToken")
            if not next_token:
                break
            kwargs["NextToken"] = next_token
        chunk_start = chunk_end
    return summaries

batch_get_traces

batch_get_traces(trace_ids)

Get full traces, batching in groups of 5 (X-Ray API limit).

Source code in src/mlflow_dynamodbstore/xray/client.py
def batch_get_traces(self, trace_ids: list[str]) -> list[dict[str, Any]]:
    """Get full traces, batching in groups of 5 (X-Ray API limit)."""
    traces: list[dict[str, Any]] = []
    for i in range(0, len(trace_ids), _MAX_BATCH_SIZE):
        batch = trace_ids[i : i + _MAX_BATCH_SIZE]
        response = self._client.batch_get_traces(TraceIds=batch)
        traces.extend(response.get("Traces", []))
    return traces
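
The two methods are typically chained: summaries first, then full traces for the returned IDs. A minimal sketch (region, time window, and filter expression are illustrative):

# Illustrative: fetch 24 hours of summaries, then hydrate full traces.
import datetime

from mlflow_dynamodbstore.xray.client import XRayClient

client = XRayClient(region="us-east-1")
end = datetime.datetime.now(datetime.timezone.utc)
start = end - datetime.timedelta(hours=24)
summaries = client.get_trace_summaries(start, end, filter_expression='service("mlflow")')
traces = client.batch_get_traces([s["Id"] for s in summaries])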

DynamoDB Table

DynamoDBTable

DynamoDBTable(table_name, region=None, endpoint_url=None)

High-level DynamoDB table client using boto3 resource API.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def __init__(
    self,
    table_name: str,
    region: str | None = None,
    endpoint_url: str | None = None,
) -> None:
    kwargs: dict[str, Any] = {}
    if region:
        kwargs["region_name"] = region
    if endpoint_url:
        kwargs["endpoint_url"] = endpoint_url
    resource = boto3.resource("dynamodb", **kwargs)
    self._table = resource.Table(table_name)

put_item

put_item(item, condition=None)

Write an item, with optional ConditionExpression.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def put_item(self, item: dict[str, Any], condition: str | None = None) -> None:
    """Write an item, with optional ConditionExpression."""
    _validate_index_key_types(item)
    item = convert_floats(item)
    kwargs: dict[str, Any] = {"Item": item}
    if condition:
        kwargs["ConditionExpression"] = condition
    self._table.put_item(**kwargs)

get_item

get_item(pk, sk, consistent=False)

Return item by PK+SK, or None if not found.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def get_item(self, pk: str, sk: str, consistent: bool = False) -> dict[str, Any] | None:
    """Return item by PK+SK, or None if not found."""
    kwargs: dict[str, Any] = {
        "Key": {"PK": pk, "SK": sk},
        "ConsistentRead": consistent,
    }
    response = self._table.get_item(**kwargs)
    item: dict[str, Any] | None = response.get("Item")
    return item

delete_item

delete_item(pk, sk)

Delete item by PK+SK.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def delete_item(self, pk: str, sk: str) -> None:
    """Delete item by PK+SK."""
    self._table.delete_item(Key={"PK": pk, "SK": sk})

update_item

update_item(pk, sk, updates=None, removes=None, condition=None)

Update attributes on an item using SET and/or REMOVE expressions.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def update_item(
    self,
    pk: str,
    sk: str,
    updates: dict[str, Any] | None = None,
    removes: list[str] | None = None,
    condition: str | None = None,
) -> dict[str, Any] | None:
    """Update attributes on an item using SET and/or REMOVE expressions."""
    expression_parts: list[str] = []
    expr_names: dict[str, str] = {}
    expr_values: dict[str, Any] = {}

    if updates:
        _validate_index_key_types(updates)
        set_clauses = []
        for i, (attr, val) in enumerate(updates.items()):
            name_token = f"#u{i}"
            value_token = f":u{i}"
            expr_names[name_token] = attr
            expr_values[value_token] = val
            set_clauses.append(f"{name_token} = {value_token}")
        expression_parts.append("SET " + ", ".join(set_clauses))

    if removes:
        remove_tokens = []
        for i, attr in enumerate(removes):
            name_token = f"#r{i}"
            expr_names[name_token] = attr
            remove_tokens.append(name_token)
        expression_parts.append("REMOVE " + ", ".join(remove_tokens))

    if not expression_parts:
        return None

    kwargs: dict[str, Any] = {
        "Key": {"PK": pk, "SK": sk},
        "UpdateExpression": " ".join(expression_parts),
        "ReturnValues": "ALL_NEW",
    }
    if expr_names:
        kwargs["ExpressionAttributeNames"] = expr_names
    if expr_values:
        kwargs["ExpressionAttributeValues"] = expr_values
    if condition:
        kwargs["ConditionExpression"] = condition

    response = self._table.update_item(**kwargs)
    attributes: dict[str, Any] | None = response.get("Attributes")
    return attributes
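
A minimal sketch combining SET and REMOVE in one call, using the workspace key schema documented above (values are illustrative):

# Illustrative: set one attribute and remove another in one update.
from mlflow_dynamodbstore.dynamodb.table import DynamoDBTable

table = DynamoDBTable("mlflow-table")  # example table name
updated = table.update_item(
    pk="WORKSPACE#team-a",
    sk="META",
    updates={"description": "Team A (renamed)"},
    removes=["default_artifact_root"],
)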

add_attribute

add_attribute(pk, sk, attribute, value)

Atomically increment an attribute using ADD expression. Returns updated item.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def add_attribute(
    self,
    pk: str,
    sk: str,
    attribute: str,
    value: int | float,
) -> dict[str, Any]:
    """Atomically increment an attribute using ADD expression. Returns updated item."""
    response = self._table.update_item(
        Key={"PK": pk, "SK": sk},
        UpdateExpression="ADD #attr :val",
        ExpressionAttributeNames={"#attr": attribute},
        ExpressionAttributeValues={":val": value},
        ReturnValues="UPDATED_NEW",
    )
    attributes: dict[str, Any] = response.get("Attributes", {})
    return attributes
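
A minimal counter sketch (table is a DynamoDBTable; the key and attribute names are illustrative):

# Illustrative: atomic increment; only the updated attribute comes
# back because ReturnValues is UPDATED_NEW.
attrs = table.add_attribute(pk="WORKSPACE#team-a", sk="META", attribute="run_count", value=1)
new_count = attrs["run_count"]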

query

query(pk, sk_prefix=None, sk_gte=None, sk_lte=None, index_name=None, limit=None, scan_forward=True, consistent=False, filter_expression=None)

Query the table or an index with flexible key conditions.

For index queries the PK attribute name is derived from index_name (e.g. "gsi1" -> "gsi1pk", "lsi1" -> "PK").

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def query(
    self,
    pk: str,
    sk_prefix: str | None = None,
    sk_gte: str | None = None,
    sk_lte: str | None = None,
    index_name: str | None = None,
    limit: int | None = None,
    scan_forward: bool = True,
    consistent: bool = False,
    filter_expression: ConditionBase | None = None,
) -> list[dict[str, Any]]:
    """Query the table or an index with flexible key conditions.

    For index queries the PK attribute name is derived from index_name
    (e.g. "gsi1" -> "gsi1pk", "lsi1" -> "PK").
    """
    # Determine key attribute names
    if index_name:
        pk_attr, sk_attr = _INDEX_KEY_ATTRS[index_name]
    else:
        pk_attr, sk_attr = "PK", "SK"

    # Build key condition
    key_cond: ConditionBase = Key(pk_attr).eq(pk)
    if sk_prefix:
        key_cond = key_cond & Key(sk_attr).begins_with(sk_prefix)
    elif sk_gte is not None and sk_lte is not None:
        key_cond = key_cond & Key(sk_attr).between(sk_gte, sk_lte)

    kwargs: dict[str, Any] = {
        "KeyConditionExpression": key_cond,
        "ScanIndexForward": scan_forward,
        "ConsistentRead": consistent,
    }
    if index_name:
        kwargs["IndexName"] = index_name
    if filter_expression is not None:
        kwargs["FilterExpression"] = filter_expression

    # Collect items, handling pagination
    items: list[dict[str, Any]] = []
    remaining = limit  # None means unlimited

    while True:
        if remaining is not None:
            kwargs["Limit"] = remaining

        response = self._table.query(**kwargs)
        batch: list[dict[str, Any]] = response.get("Items", [])
        items.extend(batch)

        if remaining is not None:
            remaining -= len(batch)
            if remaining <= 0:
                break

        last_key = response.get("LastEvaluatedKey")
        if not last_key:
            break
        kwargs["ExclusiveStartKey"] = last_key

    return items
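
A minimal sketch using the documented GSI2 workspace listing keys (table is a DynamoDBTable):

# Illustrative: list all workspace items via GSI2, as the workspace
# store's list_workspaces does above.
items = table.query(pk="WORKSPACES", index_name="gsi2")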

batch_write

batch_write(items)

Batch write items, chunking into groups of 25.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def batch_write(self, items: list[dict[str, Any]]) -> None:
    """Batch write items, chunking into groups of 25."""
    for item in items:
        _validate_index_key_types(item)
    with self._table.batch_writer() as batch:
        for item in items:
            batch.put_item(Item=convert_floats(item))

batch_delete

batch_delete(keys)

Batch delete items by PK+SK key dicts.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def batch_delete(self, keys: list[dict[str, Any]]) -> None:
    """Batch delete items by PK+SK key dicts."""
    with self._table.batch_writer() as batch:
        for key in keys:
            batch.delete_item(Key=key)

query_page

query_page(pk, sk_prefix=None, index_name=None, limit=None, scan_forward=True, consistent=False, exclusive_start_key=None, filter_expression=None)

Query a single page, returning (items, last_evaluated_key).

Unlike query(), this does NOT auto-exhaust pagination. Returns the raw LastEvaluatedKey for caller-managed cursors.

Source code in src/mlflow_dynamodbstore/dynamodb/table.py
def query_page(
    self,
    pk: str,
    sk_prefix: str | None = None,
    index_name: str | None = None,
    limit: int | None = None,
    scan_forward: bool = True,
    consistent: bool = False,
    exclusive_start_key: dict[str, Any] | None = None,
    filter_expression: ConditionBase | None = None,
) -> tuple[list[dict[str, Any]], dict[str, Any] | None]:
    """Query a single page, returning (items, last_evaluated_key).

    Unlike query(), this does NOT auto-exhaust pagination.
    Returns the raw LastEvaluatedKey for caller-managed cursors.
    """
    if index_name:
        pk_attr, sk_attr = _INDEX_KEY_ATTRS[index_name]
    else:
        pk_attr, sk_attr = "PK", "SK"

    key_cond: ConditionBase = Key(pk_attr).eq(pk)
    if sk_prefix:
        key_cond = key_cond & Key(sk_attr).begins_with(sk_prefix)

    kwargs: dict[str, Any] = {
        "KeyConditionExpression": key_cond,
        "ScanIndexForward": scan_forward,
        "ConsistentRead": consistent,
    }
    if index_name:
        kwargs["IndexName"] = index_name
    if limit is not None:
        kwargs["Limit"] = limit
    if exclusive_start_key is not None:
        kwargs["ExclusiveStartKey"] = exclusive_start_key
    if filter_expression is not None:
        kwargs["FilterExpression"] = filter_expression

    response = self._table.query(**kwargs)
    items: list[dict[str, Any]] = response.get("Items", [])
    lek: dict[str, Any] | None = response.get("LastEvaluatedKey")
    return items, lek
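
A minimal caller-managed pagination loop over query_page (table is a DynamoDBTable; key values are illustrative):

# Illustrative: walk all pages, carrying LastEvaluatedKey as the cursor.
cursor = None
while True:
    items, cursor = table.query_page(
        pk="WORKSPACES",
        index_name="gsi2",
        limit=100,
        exclusive_start_key=cursor,
    )
    for item in items:
        ...  # process item
    if cursor is None:
        break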

Config Reader

ConfigReader

ConfigReader(table)

Reads CONFIG items from DynamoDB, caches in memory.

Provides:

- should_denormalize(experiment_id, tag_key): matches tag_key against patterns
- should_trigram(field_type): entity name fields always return True; others are configurable
- reconcile(): reads env vars and merges with defaults, persisting to DynamoDB

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def __init__(self, table: DynamoDBTable) -> None:
    self._table = table
    # In-memory caches
    self._denormalize_patterns: list[str] | None = None
    self._fts_trigram_fields: list[str] | None = None
    self._ttl_policy: dict[str, int] | None = None
    # Per-experiment denormalize pattern cache: experiment_id -> list[str]
    self._exp_denormalize_patterns: dict[str, list[str]] = {}

get_denormalize_patterns

get_denormalize_patterns()

Return the global denormalize tag patterns, loading from DynamoDB if needed.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_denormalize_patterns(self) -> list[str]:
    """Return the global denormalize tag patterns, loading from DynamoDB if needed."""
    if self._denormalize_patterns is None:
        self._denormalize_patterns = self._load_denormalize_patterns()
    return self._denormalize_patterns

set_denormalize_patterns

set_denormalize_patterns(patterns)

Persist global denormalize patterns to DynamoDB and update cache.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def set_denormalize_patterns(self, patterns: list[str]) -> None:
    """Persist global denormalize patterns to DynamoDB and update cache."""
    # Always ensure mlflow.* is present
    merged = list(patterns)
    if "mlflow.*" not in merged:
        merged.insert(0, "mlflow.*")
    self._table.put_item(
        {
            "PK": PK_CONFIG,
            "SK": CONFIG_DENORMALIZE_TAGS,
            "patterns": merged,
        }
    )
    self._denormalize_patterns = merged
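
Because "mlflow.*" is re-inserted at the front whenever it is missing, system-tag denormalization cannot be switched off:

config.set_denormalize_patterns(["team.*", "priority"])
config.get_denormalize_patterns()
# -> ["mlflow.*", "team.*", "priority"]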

set_experiment_denormalize_patterns

set_experiment_denormalize_patterns(experiment_id, patterns)

Persist per-experiment denormalize patterns to DynamoDB and update cache.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def set_experiment_denormalize_patterns(self, experiment_id: str, patterns: list[str]) -> None:
    """Persist per-experiment denormalize patterns to DynamoDB and update cache."""
    sk = f"{_SK_EXP_DENORMALIZE_PREFIX}{experiment_id}"
    self._table.put_item(
        {
            "PK": PK_CONFIG,
            "SK": sk,
            "patterns": list(patterns),
        }
    )
    self._exp_denormalize_patterns[experiment_id] = list(patterns)

get_experiment_denormalize_patterns

get_experiment_denormalize_patterns(experiment_id)

Return per-experiment denormalize patterns (not merged with global).

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_experiment_denormalize_patterns(self, experiment_id: str) -> list[str]:
    """Return per-experiment denormalize patterns (not merged with global)."""
    if experiment_id not in self._exp_denormalize_patterns:
        sk = f"{_SK_EXP_DENORMALIZE_PREFIX}{experiment_id}"
        item = self._table.get_item(pk=PK_CONFIG, sk=sk)
        if item and "patterns" in item:
            self._exp_denormalize_patterns[experiment_id] = list(item["patterns"])
        else:
            self._exp_denormalize_patterns[experiment_id] = []
    return self._exp_denormalize_patterns[experiment_id]

get_effective_denormalize_patterns

get_effective_denormalize_patterns(experiment_id)

Return merged global + per-experiment denormalize patterns.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_effective_denormalize_patterns(self, experiment_id: str) -> list[str]:
    """Return merged global + per-experiment denormalize patterns."""
    global_patterns = self.get_denormalize_patterns()
    exp_patterns = self.get_experiment_denormalize_patterns(experiment_id)
    # Merge, preserving order, deduplicating
    merged: list[str] = list(global_patterns)
    for p in exp_patterns:
        if p not in merged:
            merged.append(p)
    return merged
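
The merge is order-preserving, global patterns first, with experiment-level duplicates dropped. A short illustration (the experiment ID is a placeholder):

config.set_denormalize_patterns(["mlflow.*"])
config.set_experiment_denormalize_patterns("<exp_id>", ["dataset.*", "mlflow.*"])
config.get_effective_denormalize_patterns("<exp_id>")
# -> ["mlflow.*", "dataset.*"]   (the duplicate "mlflow.*" is dropped)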

should_denormalize

should_denormalize(experiment_id, tag_key)

Return True if tag_key matches any effective denormalize pattern.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def should_denormalize(self, experiment_id: str | None, tag_key: str) -> bool:
    """Return True if tag_key matches any effective denormalize pattern."""
    if experiment_id is not None:
        patterns = self.get_effective_denormalize_patterns(experiment_id)
    else:
        patterns = self.get_denormalize_patterns()
    return any(fnmatch.fnmatch(tag_key, pattern) for pattern in patterns)
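
Matching uses fnmatch, i.e. shell-style globs. With experiment_id=None only the global patterns apply; assuming those are ["mlflow.*", "team"]:

config.should_denormalize(None, "mlflow.user")    # True  - matches "mlflow.*"
config.should_denormalize(None, "team")           # True  - exact match
config.should_denormalize(None, "my_custom_tag")  # False - no pattern matches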

get_fts_trigram_fields

get_fts_trigram_fields()

Return configurable FTS trigram fields (not including always-trigram fields).

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_fts_trigram_fields(self) -> list[str]:
    """Return configurable FTS trigram fields (not including always-trigram fields)."""
    if self._fts_trigram_fields is None:
        self._fts_trigram_fields = self._load_fts_trigram_fields()
    return self._fts_trigram_fields

set_fts_trigram_fields

set_fts_trigram_fields(fields)

Persist FTS trigram fields to DynamoDB and update cache.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def set_fts_trigram_fields(self, fields: list[str]) -> None:
    """Persist FTS trigram fields to DynamoDB and update cache."""
    self._table.put_item(
        {
            "PK": PK_CONFIG,
            "SK": CONFIG_FTS_TRIGRAM_FIELDS,
            "fields": list(fields),
        }
    )
    self._fts_trigram_fields = list(fields)

should_trigram

should_trigram(field_type)

Return True if field_type should have trigram indexing.

Entity name fields (experiment_name, run_name, model_name) always return True. Other fields are checked against the configured FTS trigram fields list.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def should_trigram(self, field_type: str) -> bool:
    """Return True if field_type should have trigram indexing.

    Entity name fields (experiment_name, run_name, model_name) always return True.
    Other fields are checked against the configured FTS trigram fields list.
    """
    if field_type in _ALWAYS_TRIGRAM_FIELDS:
        return True
    return field_type in self.get_fts_trigram_fields()
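
For example (the non-name field type here is illustrative):

config.should_trigram("run_name")   # True: always-trigram entity name field
config.should_trigram("tag_value")  # True only if listed in the configured
                                    # FTS trigram fields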

get_ttl_policy

get_ttl_policy()

Return the TTL policy, loading from DynamoDB if needed.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_ttl_policy(self) -> dict[str, int]:
    """Return the TTL policy, loading from DynamoDB if needed."""
    if self._ttl_policy is None:
        self._ttl_policy = self._load_ttl_policy()
    return self._ttl_policy

set_ttl_policy

set_ttl_policy(**kwargs)

Update individual fields in the TTL policy and persist to DynamoDB.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def set_ttl_policy(self, **kwargs: int) -> None:
    """Update individual fields in the TTL policy and persist to DynamoDB."""
    policy = self.get_ttl_policy()
    for key, value in kwargs.items():
        if key not in self._DEFAULT_TTL_POLICY:
            raise ValueError(f"Unknown TTL policy field: {key}")
        policy[key] = int(value)
    self._table.put_item(
        {
            "PK": PK_CONFIG,
            "SK": CONFIG_TTL_POLICY,
            **policy,
        }
    )
    self._ttl_policy = policy
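
Judging from the getters below, the policy keys include trace_retention_days, soft_deleted_retention_days, and metric_history_retention_days (the authoritative set is _DEFAULT_TTL_POLICY):

config.set_ttl_policy(trace_retention_days=30, soft_deleted_retention_days=7)
config.set_ttl_policy(bogus_field=1)  # raises ValueError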

get_trace_ttl_seconds

get_trace_ttl_seconds()

Return trace retention in seconds, or None if disabled (0 days).

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_trace_ttl_seconds(self) -> int | None:
    """Return trace retention in seconds, or None if disabled (0 days)."""
    return self._ttl_seconds("trace_retention_days")

get_soft_deleted_ttl_seconds

get_soft_deleted_ttl_seconds()

Return soft-deleted retention in seconds, or None if disabled (0 days).

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_soft_deleted_ttl_seconds(self) -> int | None:
    """Return soft-deleted retention in seconds, or None if disabled (0 days)."""
    return self._ttl_seconds("soft_deleted_retention_days")

get_metric_history_ttl_seconds

get_metric_history_ttl_seconds()

Return metric history retention in seconds, or None if disabled (0 days).

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def get_metric_history_ttl_seconds(self) -> int | None:
    """Return metric history retention in seconds, or None if disabled (0 days)."""
    return self._ttl_seconds("metric_history_retention_days")

reconcile

reconcile()

Read env vars and merge with defaults, persisting updated config to DynamoDB.

Source code in src/mlflow_dynamodbstore/dynamodb/config.py
def reconcile(self) -> None:
    """Read env vars and merge with defaults, persisting updated config to DynamoDB."""
    env_value = os.environ.get(_ENV_DENORMALIZE_TAGS)
    if env_value:
        # Parse comma-separated patterns from env var
        env_patterns = [p.strip() for p in env_value.split(",") if p.strip()]
        # Merge with current patterns, ensuring mlflow.* is always present
        current = self.get_denormalize_patterns()
        merged: list[str] = list(current)
        for p in env_patterns:
            if p not in merged:
                merged.append(p)
        if "mlflow.*" not in merged:
            merged.insert(0, "mlflow.*")
        self.set_denormalize_patterns(merged)
    else:
        # Ensure defaults are persisted (idempotent)
        patterns = self.get_denormalize_patterns()
        self.set_denormalize_patterns(patterns)

    # Reconcile TTL policy from environment variables
    ttl_env_mapping = {
        _ENV_SOFT_DELETED_RETENTION_DAYS: "soft_deleted_retention_days",
        _ENV_TRACE_RETENTION_DAYS: "trace_retention_days",
        _ENV_METRIC_HISTORY_RETENTION_DAYS: "metric_history_retention_days",
    }
    ttl_overrides: dict[str, int] = {}
    for env_var, policy_key in ttl_env_mapping.items():
        env_val = os.environ.get(env_var)
        if env_val is not None:
            ttl_overrides[policy_key] = int(env_val)
    if ttl_overrides:
        self.set_ttl_policy(**ttl_overrides)
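
The actual environment variable names are the values of the _ENV_* constants in config.py and are not shown here; schematically, a TTL override flows through like this:

import os

# Hypothetical variable name standing in for _ENV_TRACE_RETENTION_DAYS.
os.environ["<TRACE_RETENTION_DAYS_ENV_VAR>"] = "30"
config.reconcile()  # persists trace_retention_days=30 to the CONFIG item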

fts

Full-text search tokenizers: word-level (stemmed), character-trigram, and character-bigram helpers.

tokenize_words

tokenize_words(text)

Stemmed whole-word tokens for LIKE '%complete_word%' matches.

Source code in src/mlflow_dynamodbstore/dynamodb/fts.py
def tokenize_words(text: str) -> set[str]:
    """Stemmed whole-word tokens for LIKE '%complete_word%' matches."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    words = [w for w in words if w not in STOP_WORDS and len(w) > 1]
    return set(_stemmer.stemWords(words))
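
A worked example, assuming _stemmer is an English Snowball stemmer and "the" is in STOP_WORDS:

tokenize_words("The training runs completed")
# -> {"train", "run", "complet"}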

tokenize_trigrams

tokenize_trigrams(text)

Character trigrams for LIKE '%partial%' matches.

Source code in src/mlflow_dynamodbstore/dynamodb/fts.py
def tokenize_trigrams(text: str) -> set[str]:
    """Character trigrams for LIKE '%partial%' matches."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    grams: set[str] = set()
    for word in words:
        for i in range(len(word) - 2):
            grams.add(word[i : i + 3])
    return grams
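
The sliding window is deterministic, so the output follows directly from the code; words shorter than three characters contribute nothing:

tokenize_trigrams("ResNet-50")
# words: ["resnet", "50"]
# -> {"res", "esn", "sne", "net"}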

tokenize_tail_bigrams

tokenize_tail_bigrams(text)

Last 2 characters of each word — covers end-of-word bigram positions.

Source code in src/mlflow_dynamodbstore/dynamodb/fts.py
def tokenize_tail_bigrams(text: str) -> set[str]:
    """Last 2 characters of each word — covers end-of-word bigram positions."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    return {word[-2:] for word in words if len(word) >= 2}

tokenize_bigrams

tokenize_bigrams(text)

All character bigrams of the search term (query-side only).

Source code in src/mlflow_dynamodbstore/dynamodb/fts.py
def tokenize_bigrams(text: str) -> set[str]:
    """All character bigrams of the search term (query-side only)."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    grams: set[str] = set()
    for word in words:
        for i in range(len(word) - 1):
            grams.add(word[i : i + 2])
    return grams
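
Taken together, the two bigram helpers handle search terms too short to form a trigram. The reading suggested by the docstrings: every bigram except a word's last one is the prefix of some stored trigram and can be matched there, so the index only needs an extra item for the tail position, while the query side expands the term into all of its bigrams:

tokenize_bigrams("cat")        # query side: {"ca", "at"}
tokenize_tail_bigrams("cat")   # index side: {"at"}
# "ca" is the prefix of the stored trigram "cat"; "at" is covered by the
# tail-bigram item.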

fts_items_for_text

fts_items_for_text(pk, entity_type, entity_id, field, text, levels=('W', '3', '2'), workspace=None)

Return DynamoDB item dicts (forward + reverse) for every FTS token.

Key patterns:

Forward SK : FTS#<level>#<entity_type>#<token>#<entity_id>[#<field>]
Reverse SK : FTS_REV#<entity_type>#<entity_id>[#<field>]#<level>#<token>

For experiment (entity_type="E") and model (entity_type="M") names the forward items also carry gsi2pk / gsi2sk so that a single GSI2 query can search across partitions.

Source code in src/mlflow_dynamodbstore/dynamodb/fts.py
def fts_items_for_text(
    pk: str,
    entity_type: str,
    entity_id: str,
    field: str | None,
    text: str,
    levels: tuple[str, ...] = ("W", "3", "2"),
    workspace: str | None = None,
) -> list[dict[str, Any]]:
    """Return DynamoDB item dicts (forward + reverse) for every FTS token.

    Key patterns
    ~~~~~~~~~~~~
    Forward SK : ``FTS#<level>#<entity_type>#<token>#<entity_id>[#<field>]``
    Reverse SK : ``FTS_REV#<entity_type>#<entity_id>[#<field>]#<level>#<token>``

    For experiment (``entity_type="E"``) and model (``entity_type="M"``)
    names the forward items also carry ``gsi2pk`` / ``gsi2sk`` so that a
    single GSI2 query can search across partitions.
    """
    add_gsi2 = entity_type in _GSI2_ENTITY_TYPES

    # Build the optional field suffix once.
    field_suffix = f"#{field}" if field else ""

    items: list[dict[str, Any]] = []

    for level in levels:
        tokens = _tokens_for_level(level, text)
        for token in tokens:
            forward_sk = f"{SK_FTS_PREFIX}{level}#{entity_type}#{token}#{entity_id}{field_suffix}"
            reverse_sk = (
                f"{SK_FTS_REV_PREFIX}{entity_type}#{entity_id}{field_suffix}#{level}#{token}"
            )

            forward: dict[str, Any] = {"PK": pk, "SK": forward_sk}
            reverse: dict[str, Any] = {"PK": pk, "SK": reverse_sk}

            if add_gsi2:
                gsi2pk_val = f"{GSI2_FTS_NAMES_PREFIX}{workspace}"
                gsi2sk_val = f"{level}#{entity_type}#{token}#{entity_id}{field_suffix}"
                forward[GSI2_PK] = gsi2pk_val
                forward[GSI2_SK] = gsi2sk_val

            items.append(forward)
            items.append(reverse)

    return items
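
A small illustration (the ID is a placeholder, the prefixes follow the key patterns above, and the word-level token assumes the stemmer leaves "demo" unchanged):

items = fts_items_for_text(
    pk="EXPERIMENT#<exp_id>",
    entity_type="E",
    entity_id="<exp_id>",
    field=None,
    text="demo",
    levels=("W",),
    workspace="default",
)
# Forward SK: "FTS#W#E#demo#<exp_id>"   (this item also carries GSI2 keys,
#                                        since "E" is a GSI2 entity type)
# Reverse SK: "FTS_REV#E#<exp_id>#W#demo"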

fts_diff

fts_diff(old_text, new_text, levels=('W', '3', '2'))

Compute the token-level diff between old_text and new_text.

Returns (tokens_to_add, tokens_to_remove) where each element is a set of (level, token) tuples. Common tokens appear in neither set.

Source code in src/mlflow_dynamodbstore/dynamodb/fts.py
def fts_diff(
    old_text: str | None,
    new_text: str,
    levels: tuple[str, ...] = ("W", "3", "2"),
) -> tuple[set[tuple[str, str]], set[tuple[str, str]]]:
    """Compute the token-level diff between *old_text* and *new_text*.

    Returns ``(tokens_to_add, tokens_to_remove)`` where each element is a
    ``set`` of ``(level, token)`` tuples.  Common tokens appear in neither
    set.
    """
    old_tokens: set[tuple[str, str]] = set()
    new_tokens: set[tuple[str, str]] = set()

    for level in levels:
        if old_text is not None:
            for token in _tokens_for_level(level, old_text):
                old_tokens.add((level, token))
        for token in _tokens_for_level(level, new_text):
            new_tokens.add((level, token))

    tokens_to_add = new_tokens - old_tokens
    tokens_to_remove = old_tokens - new_tokens
    return tokens_to_add, tokens_to_remove
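
A worked example for a rename, restricted to the word level for brevity (stems assume an English Snowball stemmer):

to_add, to_remove = fts_diff("baseline model", "baseline model v2", levels=("W",))
# to_add    == {("W", "v2")}
# to_remove == set()   ("baselin" and "model" are common to both texts)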