Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): add except_cols argument to when_matched_update_all and when_not_matched_insert_all for excluding specific columns #3098

Merged
merged 1 commit into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 78 additions & 4 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,21 +1522,26 @@ def when_matched_update(
self._builder.when_matched_update(updates, predicate)
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
def when_matched_update_all(
self, predicate: Optional[str] = None, except_cols: Optional[List[str]] = None
) -> "TableMerger":
"""Updating all source fields to target fields, source and target are required to have the same field names.
If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
If ``except_cols`` is specified, then the columns in the exclude list will not be updated.

Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"

Args:
predicate: SQL like predicate on when to update all columns.

except_cols: List of columns to exclude from update.
Returns:
TableMerger: TableMerger Object

Example:
** Update all columns **

```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
Expand All @@ -1563,6 +1568,35 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
1 2 5
2 3 6
```

** Update all columns except `bar` **

```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa

data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]})
write_deltalake("tmp", data)
dt = DeltaTable("tmp")
new_data = pa.table({"foo": [1], "bar": [7]})

(
dt.merge(
source=new_data,
predicate="target.foo = source.foo",
source_alias="source",
target_alias="target")
.when_matched_update_all(except_cols=["bar"])
.execute()
)
{'num_source_rows': 1, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 1, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 2, 'num_output_rows': 3, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...}

dt.to_pandas()
foo bar
0 1 4
1 2 5
2 3 6
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias
Expand All @@ -1572,9 +1606,12 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)

except_columns = except_cols or []

updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self._builder.arrow_schema
if col.name not in except_columns
}

self._builder.when_matched_update(updates, predicate)
Expand Down Expand Up @@ -1700,23 +1737,26 @@ def when_not_matched_insert(
return self

def when_not_matched_insert_all(
self, predicate: Optional[str] = None
self, predicate: Optional[str] = None, except_cols: Optional[List[str]] = None
) -> "TableMerger":
"""Insert a new row to the target table, updating all source fields to target fields. Source and target are
required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for
the new row to be inserted.
the new row to be inserted. If ``except_cols`` is specified, then the columns in the exclude list will not be inserted.

Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"

Args:
predicate: SQL like predicate on when to insert.
except_cols: List of columns to exclude from insert.

Returns:
TableMerger: TableMerger Object

Example:
** Insert all columns **

```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
Expand Down Expand Up @@ -1744,6 +1784,36 @@ def when_not_matched_insert_all(
2 3 6
3 4 7
```

** Insert all columns except `bar` **

```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa

data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]})
write_deltalake("tmp", data)
dt = DeltaTable("tmp")
new_data = pa.table({"foo": [4], "bar": [7]})

(
dt.merge(
source=new_data,
predicate='target.foo = source.foo',
source_alias='source',
target_alias='target')
.when_not_matched_insert_all(except_cols=["bar"])
.execute()
)
{'num_source_rows': 1, 'num_target_rows_inserted': 1, 'num_target_rows_updated': 0, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 3, 'num_output_rows': 4, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...}

dt.to_pandas().sort_values("foo", ignore_index=True)
foo bar
0 1 4
1 2 5
2 3 6
3 4 NaN
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias
Expand All @@ -1752,9 +1822,13 @@ def when_not_matched_insert_all(
trgt_alias = (
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)

except_columns = except_cols or []

updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self._builder.arrow_schema
if col.name not in except_columns
}

self._builder.when_not_matched_insert(updates, predicate)
Expand Down
77 changes: 77 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,45 @@ def test_merge_when_matched_update_all_wo_predicate(
assert result == expected


def test_merge_when_matched_update_all_with_exclude(
tmp_path: pathlib.Path, sample_table: pa.Table
):
write_deltalake(tmp_path, sample_table, mode="append")

dt = DeltaTable(tmp_path)

source_table = pa.table(
{
"id": pa.array(["4", "5"]),
"price": pa.array([10, 100], pa.int64()),
"sold": pa.array([15, 25], pa.int32()),
"deleted": pa.array([True, True]),
"weight": pa.array([10, 15], pa.int64()),
}
)

dt.merge(
source=source_table,
predicate="t.id = s.id",
source_alias="s",
target_alias="t",
).when_matched_update_all(except_cols=["sold"]).execute()

expected = pa.table(
{
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array([0, 1, 2, 10, 100], pa.int64()),
"sold": pa.array([0, 1, 2, 3, 4], pa.int32()),
"deleted": pa.array([False, False, False, True, True]),
}
)
result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_when_matched_update_with_predicate(
tmp_path: pathlib.Path, sample_table: pa.Table
):
Expand Down Expand Up @@ -340,6 +379,44 @@ def test_merge_when_not_matched_insert_all_with_predicate(
assert result == expected


def test_merge_when_not_matched_insert_all_with_exclude(
tmp_path: pathlib.Path, sample_table: pa.Table
):
write_deltalake(tmp_path, sample_table, mode="append")

dt = DeltaTable(tmp_path)

source_table = pa.table(
{
"id": pa.array(["6", "9"]),
"price": pa.array([10, 100], pa.int64()),
"sold": pa.array([10, 20], pa.int32()),
"deleted": pa.array([None, None], pa.bool_()),
}
)

dt.merge(
source=source_table,
source_alias="source",
target_alias="target",
predicate="target.id = source.id",
).when_not_matched_insert_all(except_cols=["sold"]).execute()

expected = pa.table(
{
"id": pa.array(["1", "2", "3", "4", "5", "6", "9"]),
"price": pa.array([0, 1, 2, 3, 4, 10, 100], pa.int64()),
"sold": pa.array([0, 1, 2, 3, 4, None, None], pa.int32()),
"deleted": pa.array([False, False, False, False, False, None, None]),
}
)
result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_when_not_matched_insert_all_with_predicate_special_column_names(
tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table
):
Expand Down
Loading