From 1af9a8e2ae1d0b6d68362ecf2570654f929410b3 Mon Sep 17 00:00:00 2001 From: sergioballesteros-vd <221996264+sergioballesteros-vd@users.noreply.github.com> Date: Fri, 19 Jun 2026 15:05:15 +0200 Subject: [PATCH] fix: preserve remainder rows across worker batches --- CHANGELOG.md | 3 +++ src/gen_fraud_graph/generator.py | 36 +++++++++++++++++++++++++------- tests/test_generator.py | 13 +++++++++++- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c495189..c7a87a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed +- Preserve all generated account and transaction rows when the requested totals do not divide evenly across worker batches. + ## [0.1.0] - 2026-05-26 ### Added diff --git a/src/gen_fraud_graph/generator.py b/src/gen_fraud_graph/generator.py index a2dd617..b49096d 100644 --- a/src/gen_fraud_graph/generator.py +++ b/src/gen_fraud_graph/generator.py @@ -36,6 +36,28 @@ ] +# --------------------------------------------------------------------------- +# Workload planning helpers +# --------------------------------------------------------------------------- + + +def _split_workload(total: int, num_shards: int) -> list[tuple[int, int]]: + """Split ``total`` rows across ``num_shards`` shards without dropping rows.""" + if num_shards <= 0: + raise ValueError("num_shards must be greater than zero") + + base, remainder = divmod(total, num_shards) + shards: list[tuple[int, int]] = [] + start = 0 + + for shard_idx in range(num_shards): + count = base + (1 if shard_idx < remainder else 0) + shards.append((start, count)) + start += count + + return shards + + # --------------------------------------------------------------------------- # Worker functions (must be top-level for multiprocessing) # --------------------------------------------------------------------------- @@ -272,22 +294,21 @@ def _generate_accounts(self) -> None: cfg = self.cfg print("\n[Phase 1] Generating accounts...") - acc_per_worker = cfg.num_accounts // cfg.workers - acc_per_batch = acc_per_worker // cfg.batches_per_worker + shard_plan = _split_workload(cfg.num_accounts, cfg.workers * cfg.batches_per_worker) with ProcessPoolExecutor(max_workers=cfg.workers) as pool: futures = [] for w in range(cfg.workers): for b in range(cfg.batches_per_worker): global_idx = w * cfg.batches_per_worker + b - start_id = global_idx * acc_per_batch + start_id, count = shard_plan[global_idx] futures.append( pool.submit( _generate_accounts_chunk, w, b, start_id, - acc_per_batch, + count, cfg.embedding_provider, cfg.embedding_dim, cfg.output_dir, @@ -301,22 +322,21 @@ def _generate_transactions(self) -> None: cfg = self.cfg print("\n[Phase 2] Generating transactions...") - tx_per_worker = cfg.num_transactions // cfg.workers - tx_per_batch = tx_per_worker // cfg.batches_per_worker + shard_plan = _split_workload(cfg.num_transactions, cfg.workers * cfg.batches_per_worker) with ProcessPoolExecutor(max_workers=cfg.workers) as pool: futures = [] for w in range(cfg.workers): for b in range(cfg.batches_per_worker): global_idx = w * cfg.batches_per_worker + b - start_id = global_idx * tx_per_batch + start_id, count = shard_plan[global_idx] futures.append( pool.submit( _generate_transactions_chunk, w, b, start_id, - tx_per_batch, + count, cfg.num_accounts, cfg.embedding_provider, cfg.embedding_dim, diff --git a/tests/test_generator.py b/tests/test_generator.py index d3a1353..8bad76a 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -15,7 +15,7 @@ from gen_fraud_graph.config import Config from gen_fraud_graph.embeddings import EmbeddingGenerator from gen_fraud_graph.exporters import get_headers, write_output -from gen_fraud_graph.generator import FraudGraphGenerator +from gen_fraud_graph.generator import FraudGraphGenerator, _split_workload from gen_fraud_graph.typologies import FraudRingGenerator from gen_fraud_graph.verify import verify_fraud_patterns @@ -73,6 +73,17 @@ def test_tiny_scale(self): assert cfg.num_fraud_rings >= 10 +class TestWorkloadPlanning: + def test_split_workload_distributes_remainder(self): + shards = _split_workload(10, 3) + assert shards == [(0, 4), (4, 3), (7, 3)] + assert sum(count for _, count in shards) == 10 + + def test_split_workload_handles_exact_division(self): + shards = _split_workload(12, 4) + assert shards == [(0, 3), (3, 3), (6, 3), (9, 3)] + + # --------------------------------------------------------------------------- # Embedding tests # ---------------------------------------------------------------------------