From fb90ff559b97bb9874777683d3386cb276aa69df Mon Sep 17 00:00:00 2001 From: samay2504 Date: Sat, 11 Apr 2026 22:39:43 +0530 Subject: [PATCH] fix: support mixed bool and float in maxp --- spopt/region/maxp.py | 2 +- spopt/tests/test_region/test_maxp.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/spopt/region/maxp.py b/spopt/region/maxp.py index ffd15024..d7bde74a 100644 --- a/spopt/region/maxp.py +++ b/spopt/region/maxp.py @@ -90,7 +90,7 @@ def maxp( """ gdf, w = modify_components(gdf, w, threshold_name, threshold, policy=policy) - attr = np.atleast_2d(gdf[attrs_name].values) + attr = np.atleast_2d(gdf[attrs_name].to_numpy(dtype=float)) if attr.shape[0] == 1: attr = attr.T threshold_array = gdf[threshold_name].values diff --git a/spopt/tests/test_region/test_maxp.py b/spopt/tests/test_region/test_maxp.py index 62ce571f..790d1707 100644 --- a/spopt/tests/test_region/test_maxp.py +++ b/spopt/tests/test_region/test_maxp.py @@ -111,6 +111,22 @@ def test_maxp_one_var(self): numpy.testing.assert_array_equal(model.labels_, self.var1_labels) + def test_maxp_mixed_bool_float_attrs(self): + self.mexico["high_gdp"] = self.mexico["PCGDP2000"] > self.mexico[ + "PCGDP2000" + ].median() + attrs_name = ["high_gdp", "PCGDP2000"] + threshold = 4 + top_n = 2 + threshold_name = "count" + numpy.random.seed(123456) + model = MaxPHeuristic( + self.mexico, self.w, attrs_name, threshold_name, threshold, top_n + ) + model.solve() + + assert len(model.labels_) == len(self.mexico) + def test_infeasible_components(self): ifcs = infeasible_components(self.mexico, self.w, "count", 35) numpy.testing.assert_array_equal(ifcs, [0])