From 04aa14d6632119ffba01b21410f3b5162f85d1bd Mon Sep 17 00:00:00 2001 From: Grain Team Date: Wed, 13 May 2026 16:06:09 -0700 Subject: [PATCH] bug fix PiperOrigin-RevId: 915109193 --- .../python/experimental/index_shuffle/index_shuffle.cc | 5 +++-- .../index_shuffle/python/index_shuffle_module.cc | 9 +++++---- .../index_shuffle/python/index_shuffle_test.py | 9 ++++++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/grain/_src/python/experimental/index_shuffle/index_shuffle.cc b/grain/_src/python/experimental/index_shuffle/index_shuffle.cc index 4fc584c26..622dfc863 100644 --- a/grain/_src/python/experimental/index_shuffle/index_shuffle.cc +++ b/grain/_src/python/experimental/index_shuffle/index_shuffle.cc @@ -58,7 +58,8 @@ namespace impl { // Returns the keys per round for the cipher. // This is not the proposed key schedule for a Simon cipher but simply uses // std::seed_seq. We found that this gives better results. -std::vector generate_keys(const uint32_t seed, const int32_t rounds) { +std::vector generate_keys(const uint32_t seed, + const uint32_t rounds) { std::vector rk(rounds); std::seed_seq seq{seed}; seq.generate(rk.begin(), rk.end()); @@ -118,7 +119,7 @@ uint64_t index_shuffle(const uint64_t index, const uint64_t max_index, block_size = std::max(block_size + block_size % 2, kMinBlockSize); assert(block_size > 0 && block_size % 2 == 0 && block_size <= 64); // At least 4 rounds and number of rounds must be even. - assert(rounds >= 4 && rounds % 2 == 0); + assert(rounds >= 4 && rounds % 2 == 0 && rounds <= 1024); // Assert the index is bounded by [0, max_index]. assert(index >= 0 && index <= max_index); #define HANDLE_BLOCK_SIZE(B) \ diff --git a/grain/_src/python/experimental/index_shuffle/python/index_shuffle_module.cc b/grain/_src/python/experimental/index_shuffle/python/index_shuffle_module.cc index 6e6f3d98c..9e47d2806 100644 --- a/grain/_src/python/experimental/index_shuffle/python/index_shuffle_module.cc +++ b/grain/_src/python/experimental/index_shuffle/python/index_shuffle_module.cc @@ -15,10 +15,11 @@ PYBIND11_MODULE(index_shuffle_module, m) { m.def( "index_shuffle", [](int64_t index, int64_t max_index, uint32_t seed, uint32_t rounds) { - if (rounds < 4 || rounds % 2 != 0) { - throw py::value_error(absl::StrCat( - "rounds must be an even integer >= 4, but got rounds = ", - rounds)); + if (rounds < 4 || rounds % 2 != 0 || rounds > 1024) { + throw py::value_error( + absl::StrCat("rounds must be an even integer between 4 and 1024, " + "but got rounds = ", + rounds)); } if (index < 0 || index > max_index) { throw py::value_error(absl::StrCat( diff --git a/grain/_src/python/experimental/index_shuffle/python/index_shuffle_test.py b/grain/_src/python/experimental/index_shuffle/python/index_shuffle_test.py index e6def6c0b..cd1b13f63 100644 --- a/grain/_src/python/experimental/index_shuffle/python/index_shuffle_test.py +++ b/grain/_src/python/experimental/index_shuffle/python/index_shuffle_test.py @@ -43,12 +43,19 @@ def test_index_shuffle_single_record(self): ) def test_index_shuffle_invalid_rounds(self): - regex = r'rounds must be an even integer >= 4' + regex = r'rounds must be an even integer between 4 and 1024' with self.assertRaisesRegex(ValueError, regex): index_shuffle.index_shuffle(index=0, max_index=8, seed=33, rounds=2) with self.assertRaisesRegex(ValueError, regex): index_shuffle.index_shuffle(index=0, max_index=8, seed=76, rounds=5) + def test_index_shuffle_rounds_too_large(self): + regex = r'rounds must be an even integer between 4 and 1024' + with self.assertRaisesRegex(ValueError, regex): + index_shuffle.index_shuffle( + index=0, max_index=1, seed=0, rounds=4294967294 + ) + def test_index_shuffle_invalid_index(self): regex = r'index must be in \[0, max_index\]' with self.assertRaisesRegex(ValueError, regex):