Skip to content
114 changes: 102 additions & 12 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import itertools
import os
import random
import time
import uuid
import warnings
from abc import ABC, abstractmethod
Expand All @@ -31,6 +33,7 @@
from pydantic import Field

import pyiceberg.expressions.parser as parser
from pyiceberg.exceptions import CommitFailedException
from pyiceberg.expressions import AlwaysFalse, AlwaysTrue, And, BooleanExpression, EqualTo, IsNull, Or, Reference
from pyiceberg.expressions.visitors import (
ResidualEvaluator,
Expand Down Expand Up @@ -205,6 +208,22 @@ class TableProperties:
MIN_SNAPSHOTS_TO_KEEP = "history.expire.min-snapshots-to-keep"
MIN_SNAPSHOTS_TO_KEEP_DEFAULT = 1

# Maximum number of commit retries after the initial attempt fails with a conflict.
COMMIT_NUM_RETRIES = "commit.retry.num-retries"
COMMIT_NUM_RETRIES_DEFAULT = 4

# Minimum backoff between commit retry attempts, in milliseconds.
COMMIT_MIN_RETRY_WAIT_MS = "commit.retry.min-wait-ms"
COMMIT_MIN_RETRY_WAIT_MS_DEFAULT = 100

# Upper bound on the (exponentially growing) backoff between retries, in milliseconds.
COMMIT_MAX_RETRY_WAIT_MS = "commit.retry.max-wait-ms"
COMMIT_MAX_RETRY_WAIT_MS_DEFAULT = 60000

# Total wall-clock budget for all commit retry attempts combined, in milliseconds.
COMMIT_TOTAL_RETRY_TIME_MS = "commit.retry.total-timeout-ms"
COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT = 1800000  # 30 minutes

# Isolation level applied when validating concurrent changes for delete / update
# operations; see the IsolationLevel enum for the accepted values.
WRITE_DELETE_ISOLATION_LEVEL = "write.delete.isolation-level"
WRITE_UPDATE_ISOLATION_LEVEL = "write.update.isolation-level"
WRITE_ISOLATION_LEVEL_DEFAULT = "serializable"


class Transaction:
_table: Table
Expand All @@ -223,6 +242,7 @@ def __init__(self, table: Table, autocommit: bool = False):
self._autocommit = autocommit
self._updates = ()
self._requirements = ()
self._snapshot_producers: list[Any] = []

@property
def table_metadata(self) -> TableMetadata:
Expand Down Expand Up @@ -265,6 +285,10 @@ def _stage(

return self

def _register_snapshot_producer(self, producer: Any) -> None:
"""Register a snapshot producer for retry support."""
self._snapshot_producers.append(producer)

def _apply(
self,
updates: tuple[TableUpdate, ...],
Expand Down Expand Up @@ -546,7 +570,12 @@ def dynamic_partition_overwrite(
delete_filter = self._build_partition_predicate(
partition_records=partitions_to_overwrite, spec=self.table_metadata.spec(), schema=self.table_metadata.schema()
)
self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties, branch=branch)
self.delete(
delete_filter=delete_filter,
snapshot_properties=snapshot_properties,
branch=branch,
_isolation_level_property=TableProperties.WRITE_UPDATE_ISOLATION_LEVEL,
)

with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
append_files.commit_uuid = append_snapshot_commit_uuid
Expand Down Expand Up @@ -603,6 +632,7 @@ def overwrite(
case_sensitive=case_sensitive,
snapshot_properties=snapshot_properties,
branch=branch,
_isolation_level_property=TableProperties.WRITE_UPDATE_ISOLATION_LEVEL,
)

with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
Expand All @@ -620,6 +650,7 @@ def delete(
snapshot_properties: dict[str, str] = EMPTY_DICT,
case_sensitive: bool = True,
branch: str | None = MAIN_BRANCH,
_isolation_level_property: str | None = None,
) -> None:
"""
Shorthand for deleting record from a table.
Expand Down Expand Up @@ -647,6 +678,8 @@ def delete(
delete_filter = _parse_row_filter(delete_filter)

with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot:
if _isolation_level_property is not None:
delete_snapshot._isolation_level_property = _isolation_level_property
delete_snapshot.delete_by_predicate(delete_filter, case_sensitive)

# Check if there are any files that require an actual rewrite of a data file
Expand Down Expand Up @@ -702,7 +735,10 @@ def delete(
with self.update_snapshot(
snapshot_properties=snapshot_properties, branch=branch
).overwrite() as overwrite_snapshot:
if _isolation_level_property is not None:
overwrite_snapshot._isolation_level_property = _isolation_level_property
overwrite_snapshot.commit_uuid = commit_uuid
overwrite_snapshot.delete_by_predicate(delete_filter, case_sensitive)
for original_data_file, replaced_data_files in replaced_files:
overwrite_snapshot.delete_data_file(original_data_file)
for replaced_data_file in replaced_data_files:
Expand Down Expand Up @@ -939,17 +975,73 @@ def commit_transaction(self) -> Table:
The table with the updates applied.
"""
if len(self._updates) > 0:
self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),)
self._table._do_commit( # pylint: disable=W0212
updates=self._updates,
requirements=self._requirements,
from pyiceberg.utils.properties import property_as_int

properties = self._table.metadata.properties
num_retries_val = property_as_int(
properties, TableProperties.COMMIT_NUM_RETRIES, TableProperties.COMMIT_NUM_RETRIES_DEFAULT
)
num_retries = num_retries_val if num_retries_val is not None else TableProperties.COMMIT_NUM_RETRIES_DEFAULT
min_wait_val = property_as_int(
properties, TableProperties.COMMIT_MIN_RETRY_WAIT_MS, TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT
)
min_wait_ms = min_wait_val if min_wait_val is not None else TableProperties.COMMIT_MIN_RETRY_WAIT_MS_DEFAULT
max_wait_val = property_as_int(
properties, TableProperties.COMMIT_MAX_RETRY_WAIT_MS, TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT
)
max_wait_ms = max_wait_val if max_wait_val is not None else TableProperties.COMMIT_MAX_RETRY_WAIT_MS_DEFAULT
total_timeout_val = property_as_int(
properties, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS, TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT
)
total_timeout_ms = (
total_timeout_val if total_timeout_val is not None else TableProperties.COMMIT_TOTAL_RETRY_TIME_MS_DEFAULT
)
start_time = time.monotonic()

for attempt in range(num_retries + 1):
try:
self._requirements += (AssertTableUUID(uuid=self.table_metadata.table_uuid),)
self._table._do_commit( # pylint: disable=W0212
updates=self._updates,
requirements=self._requirements,
)
self._cleanup_uncommitted_manifests()
break
except CommitFailedException:
elapsed_ms = (time.monotonic() - start_time) * 1000
if attempt == num_retries or not self._snapshot_producers or elapsed_ms >= total_timeout_ms:
raise

wait = min(min_wait_ms * (2**attempt), max_wait_ms)
jitter = random.uniform(0, 0.25 * wait)
time.sleep((wait + jitter) / 1000.0)

self._table.refresh()
self._rebuild_snapshot_updates()

self._updates = ()
self._requirements = ()

return self._table

def _cleanup_uncommitted_manifests(self) -> None:
"""Clean up manifests from failed retry attempts after a successful commit."""
for producer in self._snapshot_producers:
producer._cleanup_uncommitted()

def _rebuild_snapshot_updates(self) -> None:
    """Regenerate snapshot updates against refreshed table metadata for a commit retry.

    Stale snapshot/ref updates and their requirements reference the pre-refresh
    metadata, so they are dropped before each registered producer re-validates
    and re-emits its changes.
    """
    from pyiceberg.table.update import AddSnapshotUpdate, AssertRefSnapshotId, SetSnapshotRefUpdate

    stale_update_types = (AddSnapshotUpdate, SetSnapshotRefUpdate)
    stale_requirement_types = (AssertRefSnapshotId, AssertTableUUID)
    self._updates = tuple(update for update in self._updates if not isinstance(update, stale_update_types))
    self._requirements = tuple(
        requirement for requirement in self._requirements if not isinstance(requirement, stale_requirement_types)
    )

    for producer in self._snapshot_producers:
        producer._refresh_for_retry()
        producer._validate_concurrency()
        new_updates, new_requirements = producer._commit()
        self._stage(new_updates, new_requirements)


class CreateTableTransaction(Transaction):
"""A transaction that involves the creation of a new table."""
Expand Down Expand Up @@ -1961,13 +2053,11 @@ def _build_residual_evaluator(self, spec_id: int) -> Callable[[DataFile], Residu
# The lambda created here is run in multiple threads.
# So we avoid creating _EvaluatorExpression methods bound to a single
# shared instance across multiple threads.
return lambda datafile: (
residual_evaluator_of(
spec=spec,
expr=self.row_filter,
case_sensitive=self.case_sensitive,
schema=self.table_metadata.schema(),
)
return lambda datafile: residual_evaluator_of(
spec=spec,
expr=self.row_filter,
case_sensitive=self.case_sensitive,
schema=self.table_metadata.schema(),
)

@staticmethod
Expand Down
7 changes: 7 additions & 0 deletions pyiceberg/table/snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def __repr__(self) -> str:
return f"Operation.{self.name}"


class IsolationLevel(str, Enum):
    """Isolation level used when validating concurrent writes during a transaction.

    Inherits from ``str`` so members compare equal to, and serialize as, their
    plain property-string values.
    """

    SERIALIZABLE = "serializable"
    SNAPSHOT = "snapshot"


class UpdateMetrics:
added_file_size: int
removed_file_size: int
Expand Down
Loading
Loading