Skip to content

Commit

Permalink
add except cols
Browse files Browse the repository at this point in the history
Signed-off-by: jsj <jsj>
  • Loading branch information
jsj committed Jan 4, 2025
1 parent a639dea commit 439c340
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 4 deletions.
82 changes: 78 additions & 4 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,21 +1522,26 @@ def when_matched_update(
self._builder.when_matched_update(updates, predicate)
return self

def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger":
def when_matched_update_all(
self, predicate: Optional[str] = None, except_cols: Optional[List[str]] = None
) -> "TableMerger":
"""Updating all source fields to target fields, source and target are required to have the same field names.
If a ``predicate`` is specified, then it must evaluate to true for the row to be updated.
If ``except_cols`` is specified, then the columns in the exclude list will not be updated.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
predicate: SQL like predicate on when to update all columns.
except_cols: List of columns to exclude from update.
Returns:
TableMerger: TableMerger Object
Example:
** Update all columns **
```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
Expand All @@ -1563,6 +1568,35 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
1 2 5
2 3 6
```
** Update all columns except `bar` **
```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]})
write_deltalake("tmp", data)
dt = DeltaTable("tmp")
new_data = pa.table({"foo": [1], "bar": [7]})
(
dt.merge(
source=new_data,
predicate="target.foo = source.foo",
source_alias="source",
target_alias="target")
.when_matched_update_all(except_cols=["bar"])
.execute()
)
{'num_source_rows': 1, 'num_target_rows_inserted': 0, 'num_target_rows_updated': 1, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 2, 'num_output_rows': 3, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...}
dt.to_pandas()
foo bar
0 1 4
1 2 5
2 3 6
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias
Expand All @@ -1572,9 +1606,12 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)

except_columns = except_cols or []

updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self._builder.arrow_schema
if col.name not in except_columns
}

self._builder.when_matched_update(updates, predicate)
Expand Down Expand Up @@ -1700,23 +1737,26 @@ def when_not_matched_insert(
return self

def when_not_matched_insert_all(
self, predicate: Optional[str] = None
self, predicate: Optional[str] = None, except_cols: Optional[List[str]] = None
) -> "TableMerger":
"""Insert a new row to the target table, updating all source fields to target fields. Source and target are
required to have the same field names. If a ``predicate`` is specified, then it must evaluate to true for
the new row to be inserted.
the new row to be inserted. If ``except_cols`` is specified, then the columns in the exclude list will not be inserted.
Note:
Column names with special characters, such as numbers or spaces should be encapsulated
in backticks: "target.`123column`" or "target.`my column`"
Args:
predicate: SQL like predicate on when to insert.
except_cols: List of columns to exclude from insert.
Returns:
TableMerger: TableMerger Object
Example:
** Insert all columns **
```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
Expand Down Expand Up @@ -1744,6 +1784,36 @@ def when_not_matched_insert_all(
2 3 6
3 4 7
```
** Insert all columns except `bar` **
```python
from deltalake import DeltaTable, write_deltalake
import pyarrow as pa
data = pa.table({"foo": [1, 2, 3], "bar": [4, 5, 6]})
write_deltalake("tmp", data)
dt = DeltaTable("tmp")
new_data = pa.table({"foo": [4], "bar": [7]})
(
dt.merge(
source=new_data,
predicate='target.foo = source.foo',
source_alias='source',
target_alias='target')
.when_not_matched_insert_all(except_cols=["bar"])
.execute()
)
{'num_source_rows': 1, 'num_target_rows_inserted': 1, 'num_target_rows_updated': 0, 'num_target_rows_deleted': 0, 'num_target_rows_copied': 3, 'num_output_rows': 4, 'num_target_files_added': 1, 'num_target_files_removed': 1, 'execution_time_ms': ..., 'scan_time_ms': ..., 'rewrite_time_ms': ...}
dt.to_pandas().sort_values("foo", ignore_index=True)
foo bar
0 1 4
1 2 5
2 3 6
3 4 NaN
```
"""
maybe_source_alias = self._builder.source_alias
maybe_target_alias = self._builder.target_alias
Expand All @@ -1752,9 +1822,13 @@ def when_not_matched_insert_all(
trgt_alias = (
(maybe_target_alias + ".") if maybe_target_alias is not None else ""
)

except_columns = except_cols or []

updates = {
f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`"
for col in self._builder.arrow_schema
if col.name not in except_columns
}

self._builder.when_not_matched_insert(updates, predicate)
Expand Down
77 changes: 77 additions & 0 deletions python/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,45 @@ def test_merge_when_matched_update_all_wo_predicate(
assert result == expected


def test_merge_when_matched_update_all_with_exclude(
tmp_path: pathlib.Path, sample_table: pa.Table
):
write_deltalake(tmp_path, sample_table, mode="append")

dt = DeltaTable(tmp_path)

source_table = pa.table(
{
"id": pa.array(["4", "5"]),
"price": pa.array([10, 100], pa.int64()),
"sold": pa.array([15, 25], pa.int32()),
"deleted": pa.array([True, True]),
"weight": pa.array([10, 15], pa.int64()),
}
)

dt.merge(
source=source_table,
predicate="t.id = s.id",
source_alias="s",
target_alias="t",
).when_matched_update_all(except_cols=["sold"]).execute()

expected = pa.table(
{
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array([0, 1, 2, 10, 100], pa.int64()),
"sold": pa.array([0, 1, 2, 3, 4], pa.int32()),
"deleted": pa.array([False, False, False, True, True]),
}
)
result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_when_matched_update_with_predicate(
tmp_path: pathlib.Path, sample_table: pa.Table
):
Expand Down Expand Up @@ -340,6 +379,44 @@ def test_merge_when_not_matched_insert_all_with_predicate(
assert result == expected


def test_merge_when_not_matched_insert_all_with_exclude(
tmp_path: pathlib.Path, sample_table: pa.Table
):
write_deltalake(tmp_path, sample_table, mode="append")

dt = DeltaTable(tmp_path)

source_table = pa.table(
{
"id": pa.array(["6", "9"]),
"price": pa.array([10, 100], pa.int64()),
"sold": pa.array([10, 20], pa.int32()),
"deleted": pa.array([None, None], pa.bool_()),
}
)

dt.merge(
source=source_table,
source_alias="source",
target_alias="target",
predicate="target.id = source.id",
).when_not_matched_insert_all(except_cols=["sold"]).execute()

expected = pa.table(
{
"id": pa.array(["1", "2", "3", "4", "5", "6", "9"]),
"price": pa.array([0, 1, 2, 3, 4, 10, 100], pa.int64()),
"sold": pa.array([0, 1, 2, 3, 4, None, None], pa.int32()),
"deleted": pa.array([False, False, False, False, False, None, None]),
}
)
result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
last_action = dt.history(1)[0]

assert last_action["operation"] == "MERGE"
assert result == expected


def test_merge_when_not_matched_insert_all_with_predicate_special_column_names(
tmp_path: pathlib.Path, sample_table_with_spaces_numbers: pa.Table
):
Expand Down

0 comments on commit 439c340

Please sign in to comment.