diff --git a/pyproject.toml b/pyproject.toml index 9c08bc7..b491358 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,9 @@ build-backend = "setuptools.build_meta" name = "lemke" version = "0.0.1" dependencies = [ - "numpy>=2.2,<2.3", - "matplotlib>=3.10,<3.11" + "numpy>=2.2", + "matplotlib>=3.10", + "click>=8.1", ] requires-python = ">=3.10" @@ -43,6 +44,7 @@ Repository = "https://github.com/gambitproject/lemke.git" [project.scripts] lemke = "lemke.lemke:main" bimatrix = "lemke.bimatrix:main" +randomstart = "lemke.randomstart:main" [tool.setuptools.packages.find] where = ["src"] diff --git a/src/lemke/randomstart.py b/src/lemke/randomstart.py index 49e23df..c7aca63 100644 --- a/src/lemke/randomstart.py +++ b/src/lemke/randomstart.py @@ -1,9 +1,11 @@ import fractions import random -import sys +import click import matplotlib.pyplot as plt +MAX_ACCURACY = 10_000_000 + # give random n-tuple uniformly from unit simplex def randInSimplex(n, naive=False): @@ -31,6 +33,9 @@ def randInSimplex(n, naive=False): # round an array of probabilities to fractions with # denominator def roundArray(x, accuracy=10000): + if not 1 <= accuracy <= MAX_ACCURACY: + raise ValueError(f"accuracy must be between 1 and {MAX_ACCURACY}") + n = len(x) sum = 0 numerator = [0] * n @@ -68,24 +73,37 @@ def maptotriangle(vec): return x, y -def main(): - arglist = sys.argv - print("Usage: ", arglist[0], - "[numpoints [accuracy [higherdim ['n[aive]']]]]") - numpoints = 200 # number of points plotted - accuracy = 20 # coarse accuracy - higherdim = 3 # display middle 3 dimensions - naiveplot = False # if True just sum random numbers - if len(arglist) > 1: - numpoints = int(arglist[1]) - if len(arglist) > 2: - accuracy = int(arglist[2]) - if len(arglist) > 3: - a = int(arglist[3]) - if 2 < a < 11: - higherdim = a - if len(arglist) > 4: - naiveplot = True +@click.command( + context_settings={"help_option_names": ["-?", "-h", "--help"]}, +) +@click.option( + "--numpoints", + default=200, + show_default=True, + help="Number of points plotted", +) +@click.option( + "--accuracy", + default=20, + show_default=True, + help="Denominator x: each coordinate is rounded to the nearest multiple of 1/x", + type=click.IntRange(1, MAX_ACCURACY), + metavar="INTEGER", +) +@click.option( + "--higherdim", + default=3, + show_default=True, + help="Dimension from which the middle 3 components will be sampled", + type=click.IntRange(3, 10), + metavar="INTEGER", +) +@click.option( + "--naiveplot", + is_flag=True, + help="Sample naively by normalizing random uniforms (biased toward center)", +) +def main(numpoints, accuracy, higherdim, naiveplot): print( f"numpoints={numpoints} accuracy={accuracy} higherdim={higherdim} naiveplot={naiveplot}" ) diff --git a/tests/test_randomstart.py b/tests/test_randomstart.py new file mode 100644 index 0000000..9a2eb6d --- /dev/null +++ b/tests/test_randomstart.py @@ -0,0 +1,145 @@ +import fractions +import math +from unittest.mock import patch + +import matplotlib +import pytest +from click.testing import CliRunner + +from lemke.randomstart import ( + MAX_ACCURACY, + main, + maptotriangle, + randInSimplex, + renormalize, + roundArray, +) + +matplotlib.use("Agg") + + +@pytest.mark.parametrize("n", [2, 3, 5, 20]) +@pytest.mark.parametrize("naive", [True, False]) +class TestRandInSimplex: + def test_output_length(self, n, naive): + result = randInSimplex(n, naive) + assert len(result) == n + + def test_sum_is_approx_one(self, n, naive): + result = randInSimplex(n, naive) + assert sum(result) == pytest.approx(1.0) + + def test_components_in_range(self, n, naive): + result = randInSimplex(n, naive) + assert all(0.0 <= x <= 1.0 for x in result) + + +@pytest.mark.parametrize( + "array", + [ + [1.0, 0.0], + [0.3333, 0.3333, 0.3334], + [0.1, 0.2, 0.3, 0.4], + [0.0, 0.5, 0.5], + [0.111111111111111] * 9, + ], +) +@pytest.mark.parametrize( + "accuracy", + [ + 10, + 100, + 10000, + MAX_ACCURACY, + ], +) +class TestRoundArraySuccess: + def test_output_length(self, array, accuracy): + result = roundArray(array, accuracy) + assert len(result) == len(array) + + def test_sum_is_exactly_one(self, array, accuracy): + result = roundArray(array, accuracy) + assert sum(result) == fractions.Fraction(1, 1) + + def test_returns_fractions(self, array, accuracy): + result = roundArray(array, accuracy) + assert all(isinstance(x, fractions.Fraction) for x in result) + + def test_denominators_match_accuracy(self, array, accuracy): + result = roundArray(array, accuracy) + + # Check if requested accuracy is a multiple of the (possibly reduced) denominator + assert all(accuracy % x.denominator == 0 for x in result) + + +class TestRoundArrayFailure: + @pytest.mark.parametrize("bad_accuracy", [0, -1, MAX_ACCURACY + 1]) + def test_accuracy_out_of_bounds(self, bad_accuracy): + with pytest.raises(ValueError, match="accuracy must be between"): + roundArray([0.5, 0.5], accuracy=bad_accuracy) + + def test_invalid_probabilities(self): + with pytest.raises(ValueError, match="need probabilities"): + roundArray([1.0, 1.0]) + + +class TestRenormalize: + def test_all_zeros(self): + assert renormalize([0, 0, 0]) == [0, 0, 0] + + def test_single_element(self): + assert renormalize([42.0]) == [1.0] + + def test_already_normalized(self): + assert renormalize([0.2, 0.5, 0.3]) == pytest.approx([0.2, 0.5, 0.3]) + + def test_standard(self): + assert renormalize([1, 2, 3, 4]) == pytest.approx([0.1, 0.2, 0.3, 0.4]) + + +class TestMapToTriangle: + def test_vertices(self): + assert maptotriangle([1, 0, 0]) == pytest.approx((0.0, 0.0)) + assert maptotriangle([0, 1, 0]) == pytest.approx((1.0, 0.0)) + assert maptotriangle([0, 0, 1]) == pytest.approx((0.5, math.sqrt(3) / 2)) + + @pytest.mark.parametrize( + "vec, expected", + [ + ([1/3, 1/3, 1/3], (0.5, math.sqrt(3) / 6)), + ([0.5, 0.5, 0.0], (0.5, 0.0)), + ([0.0, 0.5, 0.5], (0.75, math.sqrt(3) / 4)), + ], + ) + def test_known_points(self, vec, expected): + assert maptotriangle(vec) == pytest.approx(expected) + + +class TestCLI: + @pytest.mark.parametrize( + "arguments", + [ + [], + ["--numpoints", "10", "--accuracy", "100", "--higherdim", "7", "--naiveplot"], + ] + ) + def test_cli_runs_without_error(self, arguments): + runner = CliRunner() + with patch("matplotlib.pyplot.show"): + result = runner.invoke(main, arguments) + assert result.exit_code == 0 + + @pytest.mark.parametrize("higherdim", ["-1", "0", "2", "20"]) + def test_cli_invalid_higherdim(self, higherdim): + runner = CliRunner() + result = runner.invoke(main, ["--higherdim", higherdim]) + assert result.exit_code != 0 + assert "Invalid value" in result.output + + @pytest.mark.parametrize("accuracy", ["-1", "0", f"{MAX_ACCURACY + 1}"]) + def test_cli_invalid_accuracy(self, accuracy): + runner = CliRunner() + result = runner.invoke(main, ["--accuracy", accuracy]) + assert result.exit_code != 0 + assert "Invalid value" in result.output