diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e60e254f4..8857bf5ceb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,12 +5,25 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Breaking Changes +- All optional arguments of `SubspaceDiscrete.from_simplex` after `simplex_parameters` + are now keyword-only + +### Added +- `coefficients` attribute for `DiscreteSumConstraint`, enabling weighted sums. Follows + the same pattern as `ContinuousLinearConstraint.coefficients` +- `simplex_coefficients` keyword argument to `SubspaceDiscrete.from_simplex` for + weighted simplex sum constraints + ### Changed - `BOTORCH` GP preset now includes `BetaPrior(2.5, 1.5)` for the task covariance kernel in multi-task scenarios, matching BoTorch's `MultiTaskGP` defaults introduced in version `0.18.0` - The `BOTORCH` GP preset now requires BoTorch `>= 0.18.0` and raises an `IncompatibilityError` if an older version is installed +- `DiscreteSumConstraint`, `ContinuousLinearConstraint`, and + `SubspaceDiscrete.from_simplex` now forbid 0 as coefficients +- `SubspaceDiscrete.from_simplex` no longer requires non-negative parameter values ## [0.15.0] - 2026-06-11 ### Breaking Changes diff --git a/baybe/constraints/continuous.py b/baybe/constraints/continuous.py index cfb2c4ebec..4714b8d584 100644 --- a/baybe/constraints/continuous.py +++ b/baybe/constraints/continuous.py @@ -81,6 +81,8 @@ def _validate_coefficients( # noqa: DOC101, DOC103 "The given 'coefficients' list must have one floating point entry for " "each entry in 'parameters'." ) + if any(c == 0.0 for c in coefficients): + raise ValueError("All entries in 'coefficients' must be non-zero.") @coefficients.default def _default_coefficients(self) -> tuple[float, ...]: diff --git a/baybe/constraints/discrete.py b/baybe/constraints/discrete.py index b9635faffa..d43e2f3fbe 100644 --- a/baybe/constraints/discrete.py +++ b/baybe/constraints/discrete.py @@ -3,15 +3,16 @@ from __future__ import annotations import gc -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import reduce from typing import TYPE_CHECKING, Any, ClassVar, cast +import cattrs import numpy as np import numpy.typing as npt import pandas as pd from attrs import define, field -from attrs.validators import in_, min_len +from attrs.validators import deep_iterable, in_, min_len from typing_extensions import override from baybe.constraints.base import CardinalityConstraint, DiscreteConstraint @@ -26,6 +27,7 @@ block_serialization_hook, converter, ) +from baybe.utils.validation import finite_float if TYPE_CHECKING: import polars as pl @@ -77,7 +79,11 @@ def get_invalid_polars(self) -> pl.Expr: @define class DiscreteSumConstraint(DiscreteConstraint): - """Class for modelling sum constraints.""" + """Class for modelling sum constraints. + + The constraint evaluates whether the (optionally weighted) sum of the specified + parameters satisfies the given threshold condition. + """ # IMPROVE: refactor `SumConstraint` and `ProdConstraint` to avoid code copying @@ -94,9 +100,45 @@ class DiscreteSumConstraint(DiscreteConstraint): condition: ThresholdCondition = field() """The condition modeled by this constraint.""" + coefficients: tuple[float, ...] = field( + converter=lambda x: cattrs.structure(x, tuple[float, ...]), + validator=deep_iterable(member_validator=finite_float), + ) + """The coefficients for the weighted sum, one per entry in ``parameters``. + + Defaults to all-ones, i.e. an unweighted sum.""" + + @coefficients.default + def _default_coefficients(self) -> tuple[float, ...]: + """Return equal weight coefficients as default.""" + return (1.0,) * len(self.parameters) + + @coefficients.validator + def _validate_coefficients( # noqa: DOC101, DOC103 + self, _: Any, coefficients: Sequence[float] + ) -> None: + """Validate the coefficients. + + Raises: + ValueError: If the number of coefficients does not match the number of + parameters. + """ + if len(self.parameters) != len(coefficients): + raise ValueError( + "The given 'coefficients' list must have one floating point entry for " + "each entry in 'parameters'." + ) + if any(c == 0.0 for c in coefficients): + raise ValueError("All entries in 'coefficients' must be non-zero.") + @override def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index: - evaluate_df = df[self.parameters].sum(axis=1) + evaluate_df = pd.Series( + sum( + df[p].to_numpy() * c for p, c in zip(self.parameters, self.coefficients) + ), + index=df.index, + ) mask_bad = ~self.condition.evaluate(evaluate_df) return df.index[mask_bad] @@ -105,7 +147,8 @@ def _get_invalid(self, df: pd.DataFrame, /) -> pd.Index: def get_invalid_polars(self) -> pl.Expr: from baybe._optional.polars import polars as pl - return self.condition.to_polars(pl.sum_horizontal(self.parameters)).not_() + weighted = [pl.col(p) * c for p, c in zip(self.parameters, self.coefficients)] + return self.condition.to_polars(pl.sum_horizontal(weighted)).not_() @define diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py index b8c3e1c0ae..573d6e6919 100644 --- a/baybe/searchspace/discrete.py +++ b/baybe/searchspace/discrete.py @@ -269,6 +269,8 @@ def from_simplex( cls, max_sum: float, simplex_parameters: Sequence[NumericalDiscreteParameter], + *, + simplex_coefficients: Sequence[float] | None = None, product_parameters: Sequence[DiscreteParameter] | None = None, constraints: Sequence[DiscreteConstraint] | None = None, min_nonzero: int = 0, @@ -290,8 +292,12 @@ def from_simplex( significantly faster construction. Args: - max_sum: The maximum sum of the parameter values defining the simplex size. + max_sum: The maximum (weighted) sum of the parameter values defining the + simplex size. simplex_parameters: The parameters to be used for the simplex construction. + simplex_coefficients: Optional coefficients for the weighted sum, one per + entry in ``simplex_parameters``. Defaults to all-ones, i.e. an + unweighted sum. product_parameters: Optional parameters that enter in form of a Cartesian product. constraints: See :class:`baybe.searchspace.core.SearchSpace`. @@ -304,8 +310,9 @@ def from_simplex( tolerance: Numerical tolerance used to validate the simplex constraint. Raises: - ValueError: If the passed simplex parameters are not suitable for a simplex - construction. + ValueError: If the length of ``simplex_coefficients`` does not match the + number of ``simplex_parameters``. + ValueError: If ``simplex_coefficients`` contains any zeros. ValueError: If the passed product parameters are not discrete. ValueError: If the passed simplex parameters and product parameters are not disjoint. @@ -325,6 +332,8 @@ def from_simplex( constraints = [] if max_nonzero is None: max_nonzero = len(simplex_parameters) + if simplex_coefficients is None: + simplex_coefficients = [1.0] * len(simplex_parameters) # Validate constraints validate_constraints(constraints, [*simplex_parameters, *product_parameters]) @@ -343,6 +352,18 @@ def from_simplex( f"must be of subclasses of '{DiscreteParameter.__name__}'." ) + # Validate coefficients length + if len(simplex_coefficients) != len(simplex_parameters): + raise ValueError( + f"'simplex_coefficients' must have one entry per 'simplex_parameters' " + f"entry, but got {len(simplex_coefficients)} coefficient(s) for " + f"{len(simplex_parameters)} parameter(s)." + ) + + # Validate no zero coefficients + if any(c == 0.0 for c in simplex_coefficients): + raise ValueError("All entries in 'simplex_coefficients' must be non-zero.") + # Validate no overlap between simplex parameters and product parameters simplex_parameters_names = {p.name for p in simplex_parameters} product_parameters_names = {p.name for p in product_parameters} @@ -364,79 +385,54 @@ def from_simplex( if len(simplex_parameters) < 1: return cls.from_product(product_parameters, constraints) - # Validate non-negativity - min_values = [min(p.values) for p in simplex_parameters] - max_values = [max(p.values) for p in simplex_parameters] - if not (min(min_values) >= 0.0): + # Compute per-parameter minimum weighted contributions. + # For a positive coefficient c the minimum contribution is c*min_raw; for a + # negative coefficient the ordering flips and it becomes c*max_raw. Taking + # min of both products handles any real coefficient correctly. + min_raw = [min(p.values) for p in simplex_parameters] + max_raw = [max(p.values) for p in simplex_parameters] + coeffs = np.asarray(simplex_coefficients, dtype=active_settings.DTypeFloatNumpy) + if not np.isfinite(coeffs).all(): raise ValueError( - f"All simplex_parameters passed to '{cls.from_simplex.__name__}' " - f"must have non-negative values only." + f"All simplex_coefficients passed to '{cls.from_simplex.__name__}' " + f"must be finite numbers." ) + min_weighted = np.array( + [min(c * lo, c * hi) for c, lo, hi in zip(coeffs, min_raw, max_raw)] + ) - def drop_invalid( - df: pd.DataFrame, - max_sum: float, - boundary_only: bool, - min_nonzero: int | None = None, - max_nonzero: int | None = None, - ) -> None: - """Drop rows that violate the specified simplex constraint. - - Args: - df: The dataframe whose rows should satisfy the simplex constraint. - max_sum: The maximum row sum defining the simplex size. - boundary_only: Flag to control if the points represented by the rows - may lie inside the simplex or on its boundary only. - min_nonzero: Minimum number of nonzero parameters required per row. - max_nonzero: Maximum number of nonzero parameters allowed per row. - """ - # Apply sum constraints - row_sums = df.sum(axis=1) - mask_violated = row_sums > max_sum + tolerance - if boundary_only: - mask_violated |= row_sums < max_sum - tolerance - - # Apply optional nonzero constraints - if (min_nonzero is not None) or (max_nonzero is not None): - n_nonzero = (df != 0.0).sum(axis=1) - if min_nonzero is not None: - mask_violated |= n_nonzero < min_nonzero - if max_nonzero is not None: - mask_violated |= n_nonzero > max_nonzero - - # Remove violating rows - idxs_to_drop = df[mask_violated].index - df.drop(index=idxs_to_drop, inplace=True) - - # Get the minimum sum contributions to come in the upcoming joins (the - # first item is the minimum possible sum of all parameters starting from the - # second parameter, the second item is the minimum possible sum starting from - # the third parameter, and so on ...) - min_sum_upcoming = np.cumsum(min_values[:0:-1])[::-1] - - # Get the min/max number of nonzero values to come in the upcoming joins (the - # first item is the min/max number of nonzero parameters starting from the - # second parameter, the second item is the min/max number starting from - # the third parameter, and so on ...) - min_nonzero_upcoming = np.cumsum((np.asarray(min_values) > 0.0)[:0:-1])[::-1] - max_nonzero_upcoming = np.cumsum((np.asarray(max_values) > 0.0)[:0:-1])[::-1] - - # Incrementally build up the space, dropping invalid configuration along the - # way. More specifically: - # * After having cross-joined a new parameter, there must - # be enough "room" left for the remaining parameters to fit. That is, - # configurations of the current parameter subset that exceed the desired - # total value minus the minimum contribution to come from the yet-to-be-added - # parameters can be already discarded, because it is already clear that - # the total sum will be exceeded once all joins are completed. - # * Analogously, there must be enough "nonzero slots" left for the yet to be - # joined parameters, i.e. parameter subset configurations can be discarded - # where the number of nonzero parameters already exceeds the maximum number - # of nonzeros minus the number of nonzeros to come, because it is already - # clear that the maximum will be exceeded once all joins are completed. - # * Similarly, it can be verified for each parameter that there are still - # enough nonzero parameters to come to even reach the minimum - # desired number of nonzero after all joins. + # Get the minimum weighted sum contributions to come in the upcoming joins (the + # first item is the minimum possible weighted sum of all parameters starting + # from the second parameter, the second item is the minimum possible weighted + # sum starting from the third parameter, and so on ...) + min_sum_upcoming = np.cumsum(min_weighted[:0:-1])[::-1] + + # Get the min/max number of nonzero values to come in the upcoming joins. + # Nonzero counting is based on raw parameter values, not weighted values, + # because the cardinality constraint counts zero/nonzero entries regardless + # of the coefficient signs. + min_nonzero_upcoming = np.cumsum((np.asarray(min_raw) > 0.0)[:0:-1])[::-1] + max_nonzero_upcoming = np.cumsum((np.asarray(max_raw) > 0.0)[:0:-1])[::-1] + + # Incrementally build up the space as a numpy array, dropping invalid + # configurations along the way. Working with raw numpy avoids pandas overhead + # (index management, BlockManager, merge machinery) in the hot loop. + # + # After having cross-joined a new parameter, there must be enough "room" left + # for the remaining parameters to fit. That is, configurations of the current + # parameter subset that exceed the desired total value minus the minimum + # contribution to come from the yet-to-be-added parameters can be already + # discarded, because it is already clear that the total sum will be exceeded + # once all joins are completed. Analogously, nonzero cardinality bounds are + # checked at each step. + # + # Instead of materializing the full cross-product before filtering, we use + # broadcasting to compute the validity mask in 2D (n_old, n_new) and only + # materialize the surviving combinations. This avoids allocating large + # intermediate arrays that are mostly discarded. + arr: np.ndarray + partial_sums: np.ndarray + nz_counts: np.ndarray for i, ( param, min_sum_to_go, @@ -450,27 +446,55 @@ def drop_invalid( np.append(max_nonzero_upcoming, 0), ) ): + values = np.asarray(param.values, dtype=active_settings.DTypeFloatNumpy) + threshold = (max_sum - min_sum_to_go) + tolerance + effective_min = min_nonzero - max_nonzero_to_go + effective_max = max_nonzero - min_nonzero_to_go + if i == 0: - exp_rep = pd.DataFrame({param.name: param.values}) - else: - exp_rep = pd.merge( - exp_rep, pd.DataFrame({param.name: param.values}), how="cross" - ) - drop_invalid( - exp_rep, - max_sum=max_sum - min_sum_to_go, - # the maximum possible number of nonzeros to come dictates if we - # can achieve our minimum constraint in the end: - min_nonzero=min_nonzero - max_nonzero_to_go, - # the minimum possible number of nonzeros to come dictates if we - # can stay below the targeted maximum in the end: - max_nonzero=max_nonzero - min_nonzero_to_go, - boundary_only=False, - ) + partial_sums = values * coeffs[0] + nz_counts = (values != 0.0).astype(np.intp) + + # Apply constraints directly on first parameter + mask = partial_sums <= threshold + if effective_min > 0: + mask &= nz_counts >= effective_min + if effective_max < len(simplex_parameters): + mask &= nz_counts <= effective_max + + arr = values[mask].reshape(-1, 1) + partial_sums = partial_sums[mask] + nz_counts = nz_counts[mask] + continue + + # Compute weighted sums via broadcasting: (n_old, n_new) + new_contributions = values * coeffs[i] + total_sums = partial_sums[:, None] + new_contributions[None, :] + + # Build 2D validity mask from sum constraint + mask_2d = total_sums <= threshold + + # Cardinality check via broadcasting + new_nz = (values != 0.0).astype(np.intp) + total_nz = nz_counts[:, None] + new_nz[None, :] + if effective_min > 0: + mask_2d &= total_nz >= effective_min + if effective_max < len(simplex_parameters): + mask_2d &= total_nz <= effective_max + + # Extract surviving indices and materialize only those rows + old_idx, new_idx = np.where(mask_2d) + arr = np.column_stack([arr[old_idx], values[new_idx].reshape(-1, 1)]) + partial_sums = total_sums[old_idx, new_idx] + nz_counts = total_nz[old_idx, new_idx] # If requested, keep only the boundary values if boundary_only: - drop_invalid(exp_rep, max_sum, boundary_only=True) + mask = np.abs(partial_sums - max_sum) <= tolerance + arr = arr[mask] + + # Wrap in DataFrame + exp_rep = pd.DataFrame(arr, columns=[p.name for p in simplex_parameters]) # Merge product parameters and apply constraints incrementally exp_rep = build_constrained_product( @@ -765,12 +789,29 @@ def validate_simplex_subspace_from_config(specs: dict, _) -> None: specs["simplex_parameters"], list[NumericalDiscreteParameter] ) - if not all(min(p.values) >= 0.0 for p in simplex_parameters): - raise ValueError( - f"All simplex_parameters passed to " - f"'{SubspaceDiscrete.from_simplex.__name__}' must have non-negative " - f"values only." - ) + simplex_coefficients = specs.get("simplex_coefficients", None) + if simplex_coefficients is not None: + try: + simplex_coefficients = converter.structure( + simplex_coefficients, list[float] + ) + except (IterableValidationError, TypeError, ValueError) as exc: + raise ValueError( + "'simplex_coefficients' must be a list of numeric values." + ) from exc + + if len(simplex_coefficients) != len(simplex_parameters): + raise ValueError( + f"'simplex_coefficients' must have one entry per " + f"'simplex_parameters' entry, but got " + f"{len(simplex_coefficients)} coefficient(s) for " + f"{len(simplex_parameters)} parameter(s)." + ) + + if any(c == 0.0 for c in simplex_coefficients): + raise ValueError( + "All entries in 'simplex_coefficients' must be non-zero." + ) product_parameters = specs.get("product_parameters", []) if product_parameters: diff --git a/tests/constraints/test_constraints_discrete.py b/tests/constraints/test_constraints_discrete.py index 9273ae13bf..fb850aae91 100644 --- a/tests/constraints/test_constraints_discrete.py +++ b/tests/constraints/test_constraints_discrete.py @@ -1,8 +1,14 @@ """Test for imposing discrete constraints.""" +import itertools import math +import pandas as pd import pytest +from pytest import param + +from baybe.constraints.conditions import ThresholdCondition +from baybe.constraints.discrete import DiscreteSumConstraint @pytest.fixture( @@ -275,3 +281,31 @@ def test_cardinality(campaign): min_cardinality = 1 max_cardinality = 2 assert non_zeros.between(min_cardinality, max_cardinality).all() + + +@pytest.mark.parametrize( + ("coefficients", "threshold", "operator", "n_invalid"), + [ + param(None, 1.0, "<=", 3, id="default"), + param((1.0, 1.0), 1.0, "<=", 3, id="all-ones"), + param((2.0, 1.0), 1.0, "<=", 5, id="scaled"), + param((1.0, -1.0), 0.5, "<=", 1, id="negative"), + param((1.0, 1.0), 1.0, "=", 6, id="equality"), + ], +) +def test_sum_constraint_coefficients(coefficients, threshold, operator, n_invalid): + """DiscreteSumConstraint filters correctly with default and custom coefficients.""" + kwargs = {} if coefficients is None else {"coefficients": coefficients} + constraint = DiscreteSumConstraint( + parameters=["A", "B"], + condition=ThresholdCondition(threshold=threshold, operator=operator), + **kwargs, + ) + df = pd.DataFrame( + list(itertools.product([0.0, 0.5, 1.0], repeat=2)), columns=["A", "B"] + ) + coeffs = coefficients or (1.0, 1.0) + weighted = df["A"] * coeffs[0] + df["B"] * coeffs[1] + expected = df.index[~ThresholdCondition(threshold, operator).evaluate(weighted)] + assert list(constraint.get_invalid(df)) == list(expected) + assert len(constraint.get_invalid(df)) == n_invalid diff --git a/tests/constraints/test_constraints_polars.py b/tests/constraints/test_constraints_polars.py index adbb1c5b2c..f5d772ae5c 100644 --- a/tests/constraints/test_constraints_polars.py +++ b/tests/constraints/test_constraints_polars.py @@ -2,6 +2,7 @@ import pytest from pandas.testing import assert_frame_equal +from pytest import param from baybe._optional.info import POLARS_INSTALLED from baybe.constraints import ( @@ -51,25 +52,10 @@ def _lazyframe_from_product(parameters): return res -@pytest.mark.parametrize("parameter_names", [["Fraction_1", "Fraction_2"]]) -@pytest.mark.parametrize("constraint_names", [["Constraint_8"]]) -def test_polars_prodsum1(parameters, constraints): - """Tests Polars implementation of sum constraint.""" - ldf = _lazyframe_from_product(parameters) - ldf = _apply_constraint_filter_polars(ldf, constraints) - - # Number of entries with 1,2-sum above 150 - ldf = ldf.with_columns(sum=pl.sum_horizontal(["Fraction_1", "Fraction_2"])) - ldf = ldf.filter(pl.col("sum") > 150) - num_entries = len(ldf.collect()) - - assert num_entries == 0 - - @pytest.mark.parametrize("parameter_names", [["Fraction_1", "Fraction_2"]]) @pytest.mark.parametrize("constraint_names", [["Constraint_9"]]) -def test_polars_prodsum2(parameters, constraints): - """Tests Polars implementation of product constrain.""" +def test_polars_product_constraint(parameters, constraints): + """Tests Polars implementation of product constraint.""" ldf = _lazyframe_from_product(parameters) ldf = _apply_constraint_filter_polars(ldf, constraints) @@ -85,20 +71,44 @@ def test_polars_prodsum2(parameters, constraints): assert num_entries == 0 +@pytest.mark.parametrize( + ("coefficients", "threshold", "operator"), + [ + param(None, 150.0, "<=", id="unweighted-le"), + param(None, 100.0, "=", id="unweighted-eq"), + param((2.0, 1.0), 150.0, "<=", id="weighted-le"), + param((1.0, -1.0), 50.0, "<=", id="negative-le"), + param((0.5, 0.5), 50.0, "=", id="weighted-eq"), + ], +) @pytest.mark.parametrize("parameter_names", [["Fraction_1", "Fraction_2"]]) -@pytest.mark.parametrize("constraint_names", [["Constraint_10"]]) -def test_polars_prodsum3(parameters, constraints): - """Tests Polars implementation of exact sum constraint.""" +def test_polars_sum_constraint(parameters, coefficients, threshold, operator): + """Polars and Pandas paths produce correct and identical results.""" + names = [p.name for p in parameters] + kwargs = {} if coefficients is None else {"coefficients": coefficients} + condition = ThresholdCondition(threshold=threshold, operator=operator) + constraint = DiscreteSumConstraint(parameters=names, condition=condition, **kwargs) + coeffs = coefficients or (1.0,) * len(parameters) + ldf = _lazyframe_from_product(parameters) - ldf = _apply_constraint_filter_polars(ldf, constraints) + df_pd = parameter_cartesian_prod_pandas(parameters) - # Number of entries with sum unequal to 100 - ldf = ldf.with_columns(sum=pl.sum_horizontal(["Fraction_1", "Fraction_2"])) - df = ldf.select(abs(pl.col("sum") - 100)).filter(pl.col("sum") > 0.01).collect() + _apply_constraint_filter_pandas(df_pd, [constraint]) + df_pl = _apply_constraint_filter_polars(ldf, [constraint]).collect().to_pandas() - num_entries = len(df) + # Correctness: all remaining rows satisfy the constraint + weighted_pd = sum(df_pd[n] * c for n, c in zip(names, coeffs)) + assert condition.evaluate(weighted_pd).all() - assert num_entries == 0 + weighted_pl = sum(df_pl[n] * c for n, c in zip(names, coeffs)) + assert condition.evaluate(weighted_pl).all() + + # Consistency: both paths agree + cols = df_pd.columns.tolist() + assert_frame_equal( + df_pd.sort_values(cols).reset_index(drop=True), + df_pl.sort_values(cols).reset_index(drop=True), + ) @pytest.mark.parametrize( diff --git a/tests/hypothesis_strategies/alternative_creation/test_searchspace.py b/tests/hypothesis_strategies/alternative_creation/test_searchspace.py index 662e898134..e5ccbde0b0 100644 --- a/tests/hypothesis_strategies/alternative_creation/test_searchspace.py +++ b/tests/hypothesis_strategies/alternative_creation/test_searchspace.py @@ -1,5 +1,7 @@ """Test alternative ways of creation not considered in the strategies.""" +import itertools + import hypothesis.strategies as st import numpy as np import pandas as pd @@ -8,6 +10,8 @@ from pandas.testing import assert_frame_equal from pytest import param +from baybe.constraints.conditions import ThresholdCondition +from baybe.constraints.discrete import DiscreteSumConstraint from baybe.parameters import ( CategoricalParameter, NumericalContinuousParameter, @@ -115,7 +119,7 @@ def test_discrete_searchspace_creation_from_degenerate_dataframe(): @pytest.mark.parametrize("boundary_only", (False, True)) @given( parameters=st.lists( - numerical_discrete_parameters(min_value=0.0, max_value=1.0), + numerical_discrete_parameters(min_value=-1.0, max_value=1.0), min_size=2, max_size=5, unique_by=lambda x: x.name, @@ -196,3 +200,119 @@ def test_discrete_space_creation_from_simplex_restricted(boundary_only): assert n_nonzero.max() == 4 assert len(subspace.parameters) == len(subspace.exp_rep.columns) assert all(p.name in subspace.exp_rep.columns for p in subspace.parameters) + + +_simplex_params = [ + NumericalDiscreteParameter(name="A", values=[0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="B", values=[0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="C", values=[0.0, 0.5, 1.0]), +] + +_simplex_params_negative = [ + NumericalDiscreteParameter(name="A", values=[-1.0, -0.5, 0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="B", values=[-1.0, -0.5, 0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="C", values=[-1.0, -0.5, 0.0, 0.5, 1.0]), +] + + +def _brute_force_weighted_simplex( + params, max_sum, coefficients, *, boundary_only=False, tol=1e-9 +): + """Return all combinations satisfying the weighted simplex constraint.""" + df = pd.DataFrame( + list(itertools.product(*[p.values for p in params])), + columns=[p.name for p in params], + ) + weighted = sum(df[p.name] * c for p, c in zip(params, coefficients)) + mask = weighted <= max_sum + tol + if boundary_only: + mask &= weighted >= max_sum - tol + return df[mask].reset_index(drop=True) + + +@pytest.mark.parametrize( + ("params", "coefficients", "max_sum", "boundary_only"), + [ + param(_simplex_params, None, 1.0, False, id="default"), + param(_simplex_params, [1.0, 1.0, 1.0], 1.0, False, id="explicit-ones"), + param(_simplex_params, [2.0, 1.0, 0.5], 1.5, False, id="positive"), + param(_simplex_params, [2.0, 1.0, 0.5], 1.5, True, id="positive-boundary"), + param(_simplex_params, [1.0, -0.5, 2.0], 1.0, False, id="mixed-sign"), + param(_simplex_params_negative, None, 0.5, False, id="negative-values-default"), + param( + _simplex_params_negative, + [1.0, -0.5, 2.0], + 0.5, + False, + id="negative-values-mixed-sign", + ), + param( + _simplex_params_negative, + [1.0, 1.0, 1.0], + 0.0, + True, + id="negative-values-boundary", + ), + ], +) +def test_discrete_space_creation_from_simplex_coefficients( + params, coefficients, max_sum, boundary_only +): + """Simplex subspace with coefficients matches brute-force and from_product.""" + coeffs = coefficients or [1.0, 1.0, 1.0] + cols = [p.name for p in params] + + # Ground truth via brute force + expected = _brute_force_weighted_simplex( + params, max_sum, coeffs, boundary_only=boundary_only + ) + expected = expected.sort_values(cols).reset_index(drop=True) + + # from_simplex + result_simplex = ( + SubspaceDiscrete.from_simplex( + max_sum, + params, + simplex_coefficients=coefficients, + boundary_only=boundary_only, + ) + .exp_rep.sort_values(cols) + .reset_index(drop=True) + ) + assert_frame_equal(result_simplex, expected, check_dtype=False) + + # from_product with equivalent constraint + operator = "=" if boundary_only else "<=" + constraint = DiscreteSumConstraint( + parameters=cols, + condition=ThresholdCondition(threshold=max_sum, operator=operator), + coefficients=tuple(coeffs), + ) + result_product = ( + SubspaceDiscrete.from_product(params, constraints=[constraint]) + .exp_rep.sort_values(cols) + .reset_index(drop=True) + ) + assert_frame_equal(result_product, expected, check_dtype=False) + + +@pytest.mark.parametrize( + ("simplex_coefficients", "match"), + [ + param( + [1.0], "'simplex_coefficients' must have one entry", id="length-mismatch" + ), + param([1.0, 0.0], "'simplex_coefficients' must be non-zero", id="zero-coeff"), + ], +) +def test_from_simplex_invalid_coefficients(simplex_coefficients, match): + """Invalid simplex_coefficients raise a ValueError.""" + with pytest.raises(ValueError, match=match): + SubspaceDiscrete.from_simplex( + 1.0, + [ + NumericalDiscreteParameter(name="x", values=[0.0, 0.5, 1.0]), + NumericalDiscreteParameter(name="y", values=[0.0, 0.5, 1.0]), + ], + simplex_coefficients=simplex_coefficients, + ) diff --git a/tests/hypothesis_strategies/constraints.py b/tests/hypothesis_strategies/constraints.py index e1f1014833..822c5fe2c3 100644 --- a/tests/hypothesis_strategies/constraints.py +++ b/tests/hypothesis_strategies/constraints.py @@ -27,6 +27,9 @@ from baybe.parameters.numerical import NumericalDiscreteParameter from tests.hypothesis_strategies.basic import finite_floats +_nonzero_finite_floats = finite_floats().filter(lambda x: x != 0.0) +"""A strategy producing non-zero finite floats.""" + def sub_selection_conditions(superset: list[Any] | None = None): """Generate :class:`baybe.constraints.conditions.SubSelectionCondition`.""" @@ -174,7 +177,9 @@ def discrete_permutation_invariance_constraints( return DiscretePermutationInvarianceConstraint(parameter_names, dependencies) +@st.composite def _discrete_constraints( + draw: st.DrawFn, constraint_type: ( type[DiscreteSumConstraint] | type[DiscreteProductConstraint] @@ -185,16 +190,22 @@ def _discrete_constraints( ): """Generate discrete constraints.""" if parameter_names is None: - parameters = st.lists(st.text(), unique=True, min_size=1) + params = draw(st.lists(st.text(), unique=True, min_size=1)) else: assert len(parameter_names) > 0 assert len(parameter_names) == len(set(parameter_names)) - parameters = st.just(parameter_names) - - if constraint_type in [DiscreteSumConstraint, DiscreteProductConstraint]: - return st.builds(constraint_type, parameters, threshold_conditions()) + params = parameter_names + + if constraint_type is DiscreteSumConstraint: + condition = draw(threshold_conditions()) + if draw(st.booleans()): + coefficients = draw(st.tuples(*([_nonzero_finite_floats] * len(params)))) + return DiscreteSumConstraint(params, condition, coefficients) + return DiscreteSumConstraint(params, condition) + elif constraint_type is DiscreteProductConstraint: + return DiscreteProductConstraint(params, draw(threshold_conditions())) else: - return st.builds(constraint_type, parameters) + return constraint_type(params) discrete_sum_constraints = partial(_discrete_constraints, DiscreteSumConstraint) @@ -227,7 +238,7 @@ def continuous_linear_constraints( assert len(parameter_names) > 0 assert len(parameter_names) == len(set(parameter_names)) - coefficients = draw(st.tuples(*([finite_floats()] * len(parameter_names)))) + coefficients = draw(st.tuples(*([_nonzero_finite_floats] * len(parameter_names)))) rhs = draw(finite_floats()) is_interpoint = draw(st.booleans()) diff --git a/tests/serialization/test_campaign_serialization.py b/tests/serialization/test_campaign_serialization.py index e0eadfd5e2..0bb3b5c961 100644 --- a/tests/serialization/test_campaign_serialization.py +++ b/tests/serialization/test_campaign_serialization.py @@ -34,6 +34,8 @@ def test_valid_simplex_config(simplex_config): def test_invalid_simplex_config(simplex_config): - simplex_config = simplex_config.replace("0.0, ", "-1.0, 0.0, ") + simplex_config = simplex_config.replace( + '"max_sum": 1.0', '"simplex_coefficients": [1.0, 0.0], "max_sum": 1.0' + ) with pytest.raises(ClassValidationError): Campaign.validate_config(simplex_config) diff --git a/tests/validation/test_constraint_validation.py b/tests/validation/test_constraint_validation.py index 2bee6bdd8f..db0aec3876 100644 --- a/tests/validation/test_constraint_validation.py +++ b/tests/validation/test_constraint_validation.py @@ -3,7 +3,12 @@ import pytest from pytest import param -from baybe.constraints.continuous import ContinuousCardinalityConstraint +from baybe.constraints.conditions import ThresholdCondition +from baybe.constraints.continuous import ( + ContinuousCardinalityConstraint, + ContinuousLinearConstraint, +) +from baybe.constraints.discrete import DiscreteSumConstraint @pytest.mark.parametrize( @@ -21,3 +26,26 @@ def test_invalid_cardinalities(cardinalities, error, match): """Providing an invalid parameter name raises an exception.""" with pytest.raises(error, match=match): ContinuousCardinalityConstraint(["x", "y"], *cardinalities) + + +@pytest.mark.parametrize( + ("coefficients", "match"), + [ + param((1.0, 2.0), "'coefficients' list must have one", id="length-mismatch"), + param((1.0, 0.0, 1.0), "'coefficients' must be non-zero", id="zero-coeff"), + ], +) +def test_invalid_coefficients(coefficients, match): + """Invalid coefficients raise a ValueError.""" + with pytest.raises(ValueError, match=match): + DiscreteSumConstraint( + parameters=["A", "B", "C"], + condition=ThresholdCondition(threshold=1.0, operator="<="), + coefficients=coefficients, + ) + with pytest.raises(ValueError, match=match): + ContinuousLinearConstraint( + parameters=["A", "B", "C"], + operator="<=", + coefficients=coefficients, + )