Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ build-backend = "setuptools.build_meta"
name = "lemke"
version = "0.0.1"
dependencies = [
"numpy>=2.2,<2.3",
"matplotlib>=3.10,<3.11"
"numpy>=2.2",
"matplotlib>=3.10",
"click>=8.1",
]
Comment on lines 8 to 12

requires-python = ">=3.10"
Expand Down Expand Up @@ -43,6 +44,7 @@ Repository = "https://github.com/gambitproject/lemke.git"
[project.scripts]
lemke = "lemke.lemke:main"
bimatrix = "lemke.bimatrix:main"
randomstart = "lemke.randomstart:main"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
56 changes: 37 additions & 19 deletions src/lemke/randomstart.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import fractions
import random
import sys

import click
import matplotlib.pyplot as plt

MAX_ACCURACY = 10_000_000


# give random n-tuple uniformly from unit simplex
def randInSimplex(n, naive=False):
Expand Down Expand Up @@ -31,6 +33,9 @@ def randInSimplex(n, naive=False):
# round an array <x> of probabilities to fractions with
# denominator <accuracy>
def roundArray(x, accuracy=10000):
if not 1 <= accuracy <= MAX_ACCURACY:
raise ValueError(f"accuracy must be between 1 and {MAX_ACCURACY}")
Comment on lines 35 to +37

n = len(x)
sum = 0
numerator = [0] * n
Expand Down Expand Up @@ -68,24 +73,37 @@ def maptotriangle(vec):
return x, y


def main():
arglist = sys.argv
print("Usage: ", arglist[0],
"[numpoints [accuracy [higherdim ['n[aive]']]]]")
numpoints = 200 # number of points plotted
accuracy = 20 # coarse accuracy
higherdim = 3 # display middle 3 dimensions
naiveplot = False # if True just sum random numbers
if len(arglist) > 1:
numpoints = int(arglist[1])
if len(arglist) > 2:
accuracy = int(arglist[2])
if len(arglist) > 3:
a = int(arglist[3])
if 2 < a < 11:
higherdim = a
if len(arglist) > 4:
naiveplot = True
@click.command(
context_settings={"help_option_names": ["-?", "-h", "--help"]},
)
@click.option(
"--numpoints",
default=200,
show_default=True,
help="Number of points plotted",
)
@click.option(
"--accuracy",
default=20,
show_default=True,
help="Denominator x: each coordinate is rounded to the nearest multiple of 1/x",
type=click.IntRange(1, MAX_ACCURACY),
metavar="INTEGER",
)
@click.option(
"--higherdim",
default=3,
show_default=True,
help="Dimension from which the middle 3 components will be sampled",
type=click.IntRange(3, 10),
metavar="INTEGER",
)
@click.option(
"--naiveplot",
is_flag=True,
help="Sample naively by normalizing random uniforms (biased toward center)",
)
def main(numpoints, accuracy, higherdim, naiveplot):
print(
f"numpoints={numpoints} accuracy={accuracy} higherdim={higherdim} naiveplot={naiveplot}"
)
Expand Down
145 changes: 145 additions & 0 deletions tests/test_randomstart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import fractions
import math
from unittest.mock import patch

import matplotlib
import pytest
from click.testing import CliRunner

from lemke.randomstart import (
MAX_ACCURACY,
main,
maptotriangle,
randInSimplex,
renormalize,
roundArray,
)

matplotlib.use("Agg")


@pytest.mark.parametrize("n", [2, 3, 5, 20])
@pytest.mark.parametrize("naive", [True, False])
class TestRandInSimplex:
def test_output_length(self, n, naive):
result = randInSimplex(n, naive)
assert len(result) == n

def test_sum_is_approx_one(self, n, naive):
result = randInSimplex(n, naive)
assert sum(result) == pytest.approx(1.0)

def test_components_in_range(self, n, naive):
result = randInSimplex(n, naive)
assert all(0.0 <= x <= 1.0 for x in result)


@pytest.mark.parametrize(
"array",
[
[1.0, 0.0],
[0.3333, 0.3333, 0.3334],
[0.1, 0.2, 0.3, 0.4],
[0.0, 0.5, 0.5],
[0.111111111111111] * 9,
],
)
@pytest.mark.parametrize(
"accuracy",
[
10,
100,
10000,
MAX_ACCURACY,
],
)
class TestRoundArraySuccess:
def test_output_length(self, array, accuracy):
result = roundArray(array, accuracy)
assert len(result) == len(array)

def test_sum_is_exactly_one(self, array, accuracy):
result = roundArray(array, accuracy)
assert sum(result) == fractions.Fraction(1, 1)

def test_returns_fractions(self, array, accuracy):
result = roundArray(array, accuracy)
assert all(isinstance(x, fractions.Fraction) for x in result)

def test_denominators_match_accuracy(self, array, accuracy):
result = roundArray(array, accuracy)

# Check if requested accuracy is a multiple of the (possibly reduced) denominator
assert all(accuracy % x.denominator == 0 for x in result)


class TestRoundArrayFailure:
@pytest.mark.parametrize("bad_accuracy", [0, -1, MAX_ACCURACY + 1])
def test_accuracy_out_of_bounds(self, bad_accuracy):
with pytest.raises(ValueError, match="accuracy must be between"):
roundArray([0.5, 0.5], accuracy=bad_accuracy)

def test_invalid_probabilities(self):
with pytest.raises(ValueError, match="need probabilities"):
roundArray([1.0, 1.0])


class TestRenormalize:
def test_all_zeros(self):
assert renormalize([0, 0, 0]) == [0, 0, 0]

def test_single_element(self):
assert renormalize([42.0]) == [1.0]

def test_already_normalized(self):
assert renormalize([0.2, 0.5, 0.3]) == pytest.approx([0.2, 0.5, 0.3])

def test_standard(self):
assert renormalize([1, 2, 3, 4]) == pytest.approx([0.1, 0.2, 0.3, 0.4])


class TestMapToTriangle:
def test_vertices(self):
assert maptotriangle([1, 0, 0]) == pytest.approx((0.0, 0.0))
assert maptotriangle([0, 1, 0]) == pytest.approx((1.0, 0.0))
assert maptotriangle([0, 0, 1]) == pytest.approx((0.5, math.sqrt(3) / 2))

@pytest.mark.parametrize(
"vec, expected",
[
([1/3, 1/3, 1/3], (0.5, math.sqrt(3) / 6)),
([0.5, 0.5, 0.0], (0.5, 0.0)),
([0.0, 0.5, 0.5], (0.75, math.sqrt(3) / 4)),
],
)
def test_known_points(self, vec, expected):
assert maptotriangle(vec) == pytest.approx(expected)


class TestCLI:
@pytest.mark.parametrize(
"arguments",
[
[],
["--numpoints", "10", "--accuracy", "100", "--higherdim", "7", "--naiveplot"],
]
)
def test_cli_runs_without_error(self, arguments):
runner = CliRunner()
with patch("matplotlib.pyplot.show"):
result = runner.invoke(main, arguments)
assert result.exit_code == 0

@pytest.mark.parametrize("higherdim", ["-1", "0", "2", "20"])
def test_cli_invalid_higherdim(self, higherdim):
runner = CliRunner()
result = runner.invoke(main, ["--higherdim", higherdim])
assert result.exit_code != 0
assert "Invalid value" in result.output

@pytest.mark.parametrize("accuracy", ["-1", "0", f"{MAX_ACCURACY + 1}"])
def test_cli_invalid_accuracy(self, accuracy):
runner = CliRunner()
result = runner.invoke(main, ["--accuracy", accuracy])
assert result.exit_code != 0
assert "Invalid value" in result.output
Loading