diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index ef4cb2506e..8928c1c3a3 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -70,6 +70,20 @@ def __or__(self, other: BooleanExpression) -> BooleanExpression: return Or(self, other) + def __bool__(self) -> bool: + """Reject truthiness checks on non-constant expressions. + + Truthiness is only defined for the constant expressions ``AlwaysTrue`` and + ``AlwaysFalse``, which override this method. Evaluating a predicate such as + ``if EqualTo("x", 1):`` is almost always a mistake; use ``~expr`` to negate + an expression or compare explicitly against ``AlwaysTrue()``/``AlwaysFalse()``. + """ + raise TypeError( + f"The truth value of {type(self).__name__} is ambiguous. " + "Truthiness is only defined for AlwaysTrue() and AlwaysFalse(); " + "use ~expr to negate an expression or compare against AlwaysTrue()/AlwaysFalse()." + ) + @model_validator(mode="wrap") @classmethod def handle_primitive_type(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> BooleanExpression: @@ -455,6 +469,10 @@ def __invert__(self) -> AlwaysFalse: """Transform the Expression into its negated version.""" return AlwaysFalse() + def __bool__(self) -> bool: + """Return True, the constant value of this expression.""" + return True + def __str__(self) -> str: """Return the string representation of the AlwaysTrue class.""" return "AlwaysTrue()" @@ -473,6 +491,10 @@ def __invert__(self) -> AlwaysTrue: """Transform the Expression into its negated version.""" return AlwaysTrue() + def __bool__(self) -> bool: + """Return False, the constant value of this expression.""" + return False + def __str__(self) -> str: """Return the string representation of the AlwaysFalse class.""" return "AlwaysFalse()" diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 63b87d290e..bc4ac92624 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2103,7 +2103,7 @@ def from_rest_response( return FileScanTask( data_file=data_file, delete_files=resolved_deletes, - residual=rest_task.residual_filter if rest_task.residual_filter else ALWAYS_TRUE, + residual=rest_task.residual_filter if rest_task.residual_filter is not None else ALWAYS_TRUE, ) diff --git a/tests/catalog/test_scan_planning_models.py b/tests/catalog/test_scan_planning_models.py index f2c80cfb9b..e6af338ebe 100644 --- a/tests/catalog/test_scan_planning_models.py +++ b/tests/catalog/test_scan_planning_models.py @@ -39,6 +39,7 @@ ) from pyiceberg.expressions import AlwaysTrue, EqualTo, Reference from pyiceberg.manifest import FileFormat +from pyiceberg.table import FileScanTask TEST_URI = "https://iceberg-test-catalog/" @@ -242,6 +243,24 @@ def test_scan_task_with_residual_filter_true() -> None: assert isinstance(task.residual_filter, AlwaysTrue) +def test_from_rest_response_preserves_non_constant_residual_filter() -> None: + data = { + "data-file": _rest_data_file(), + "residual-filter": {"type": "eq", "term": "x", "value": 1}, + } + rest_task = RESTFileScanTask.model_validate(data) + task = FileScanTask.from_rest_response(rest_task, []) + assert task.residual == EqualTo(Reference("x"), 1) + + +def test_from_rest_response_defaults_missing_residual_filter_to_always_true() -> None: + data = {"data-file": _rest_data_file()} + rest_task = RESTFileScanTask.model_validate(data) + assert rest_task.residual_filter is None + task = FileScanTask.from_rest_response(rest_task, []) + assert task.residual == AlwaysTrue() + + def test_empty_scan_tasks() -> None: data: dict[str, Any] = { "delete-files": [], diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 8ce48a6897..2e93466425 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -639,6 +639,33 @@ def test_invert_always() -> None: assert ~AlwaysTrue() == AlwaysFalse() +def test_always_bool() -> None: + assert bool(AlwaysTrue()) is True + assert bool(AlwaysFalse()) is False + + +def test_always_bool_control_flow() -> None: + assert (1 if AlwaysTrue() else 0) == 1 + assert (1 if AlwaysFalse() else 0) == 0 + assert not AlwaysFalse() + assert not (not AlwaysTrue()) + + +@pytest.mark.parametrize( + "expression", + [ + EqualTo("x", 1), + IsNull("x"), + And(EqualTo("x", 1), IsNull("y")), + Or(EqualTo("x", 1), IsNull("y")), + Not(EqualTo("x", 1)), + ], +) +def test_non_constant_expression_bool_raises(expression: BooleanExpression) -> None: + with pytest.raises(TypeError, match="truth value"): + bool(expression) + + def test_accessor_base_class() -> None: """Test retrieving a value at a position of a container using an accessor""" diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 532311899d..0fa5ce2a03 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1137,7 +1137,7 @@ def _set_spec_id(datafile: DataFile) -> DataFile: ), io=PyArrowFileIO(), projected_schema=schema, - row_filter=expr or AlwaysTrue(), + row_filter=expr if expr is not None else AlwaysTrue(), case_sensitive=True, ).to_table( tasks=[