Skip to content
Open

bug fix #1312

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions grain/_src/python/experimental/index_shuffle/index_shuffle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ namespace impl {
// Returns the keys per round for the cipher.
// This is not the proposed key schedule for a Simon cipher but simply uses
// std::seed_seq. We found that this gives better results.
std::vector<uint32_t> generate_keys(const uint32_t seed, const int32_t rounds) {
std::vector<uint32_t> generate_keys(const uint32_t seed,
const uint32_t rounds) {
std::vector<uint32_t> rk(rounds);
std::seed_seq seq{seed};
seq.generate(rk.begin(), rk.end());
Expand Down Expand Up @@ -118,7 +119,7 @@ uint64_t index_shuffle(const uint64_t index, const uint64_t max_index,
block_size = std::max(block_size + block_size % 2, kMinBlockSize);
assert(block_size > 0 && block_size % 2 == 0 && block_size <= 64);
// At least 4 rounds and number of rounds must be even.
assert(rounds >= 4 && rounds % 2 == 0);
assert(rounds >= 4 && rounds % 2 == 0 && rounds <= 1024);
// Assert the index is bounded by [0, max_index].
assert(index >= 0 && index <= max_index);
#define HANDLE_BLOCK_SIZE(B) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ PYBIND11_MODULE(index_shuffle_module, m) {
m.def(
"index_shuffle",
[](int64_t index, int64_t max_index, uint32_t seed, uint32_t rounds) {
if (rounds < 4 || rounds % 2 != 0) {
throw py::value_error(absl::StrCat(
"rounds must be an even integer >= 4, but got rounds = ",
rounds));
if (rounds < 4 || rounds % 2 != 0 || rounds > 1024) {
throw py::value_error(
absl::StrCat("rounds must be an even integer between 4 and 1024, "
"but got rounds = ",
rounds));
}
if (index < 0 || index > max_index) {
throw py::value_error(absl::StrCat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,19 @@ def test_index_shuffle_single_record(self):
)

def test_index_shuffle_invalid_rounds(self):
regex = r'rounds must be an even integer >= 4'
regex = r'rounds must be an even integer between 4 and 1024'
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(index=0, max_index=8, seed=33, rounds=2)
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(index=0, max_index=8, seed=76, rounds=5)

def test_index_shuffle_rounds_too_large(self):
regex = r'rounds must be an even integer between 4 and 1024'
with self.assertRaisesRegex(ValueError, regex):
index_shuffle.index_shuffle(
index=0, max_index=1, seed=0, rounds=4294967294
)

def test_index_shuffle_invalid_index(self):
regex = r'index must be in \[0, max_index\]'
with self.assertRaisesRegex(ValueError, regex):
Expand Down
Loading