Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed
- Preserve all generated account and transaction rows when the requested totals do not divide evenly across worker batches.

## [0.1.0] - 2026-05-26

### Added
Expand Down
36 changes: 28 additions & 8 deletions src/gen_fraud_graph/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@
]


# ---------------------------------------------------------------------------
# Workload planning helpers
# ---------------------------------------------------------------------------


def _split_workload(total: int, num_shards: int) -> list[tuple[int, int]]:
"""Split ``total`` rows across ``num_shards`` shards without dropping rows."""
if num_shards <= 0:
raise ValueError("num_shards must be greater than zero")

base, remainder = divmod(total, num_shards)
shards: list[tuple[int, int]] = []
start = 0

for shard_idx in range(num_shards):
count = base + (1 if shard_idx < remainder else 0)
shards.append((start, count))
start += count

return shards


# ---------------------------------------------------------------------------
# Worker functions (must be top-level for multiprocessing)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -272,22 +294,21 @@ def _generate_accounts(self) -> None:
cfg = self.cfg
print("\n[Phase 1] Generating accounts...")

acc_per_worker = cfg.num_accounts // cfg.workers
acc_per_batch = acc_per_worker // cfg.batches_per_worker
shard_plan = _split_workload(cfg.num_accounts, cfg.workers * cfg.batches_per_worker)

with ProcessPoolExecutor(max_workers=cfg.workers) as pool:
futures = []
for w in range(cfg.workers):
for b in range(cfg.batches_per_worker):
global_idx = w * cfg.batches_per_worker + b
start_id = global_idx * acc_per_batch
start_id, count = shard_plan[global_idx]
futures.append(
pool.submit(
_generate_accounts_chunk,
w,
b,
start_id,
acc_per_batch,
count,
cfg.embedding_provider,
cfg.embedding_dim,
cfg.output_dir,
Expand All @@ -301,22 +322,21 @@ def _generate_transactions(self) -> None:
cfg = self.cfg
print("\n[Phase 2] Generating transactions...")

tx_per_worker = cfg.num_transactions // cfg.workers
tx_per_batch = tx_per_worker // cfg.batches_per_worker
shard_plan = _split_workload(cfg.num_transactions, cfg.workers * cfg.batches_per_worker)

with ProcessPoolExecutor(max_workers=cfg.workers) as pool:
futures = []
for w in range(cfg.workers):
for b in range(cfg.batches_per_worker):
global_idx = w * cfg.batches_per_worker + b
start_id = global_idx * tx_per_batch
start_id, count = shard_plan[global_idx]
futures.append(
pool.submit(
_generate_transactions_chunk,
w,
b,
start_id,
tx_per_batch,
count,
cfg.num_accounts,
cfg.embedding_provider,
cfg.embedding_dim,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from gen_fraud_graph.config import Config
from gen_fraud_graph.embeddings import EmbeddingGenerator
from gen_fraud_graph.exporters import get_headers, write_output
from gen_fraud_graph.generator import FraudGraphGenerator
from gen_fraud_graph.generator import FraudGraphGenerator, _split_workload
from gen_fraud_graph.typologies import FraudRingGenerator
from gen_fraud_graph.verify import verify_fraud_patterns

Expand Down Expand Up @@ -73,6 +73,17 @@ def test_tiny_scale(self):
assert cfg.num_fraud_rings >= 10


class TestWorkloadPlanning:
    """Checks for the shard plan produced by ``_split_workload``."""

    def test_split_workload_distributes_remainder(self):
        """An uneven total gives the extra row to the leading shard."""
        expected_plan = [(0, 4), (4, 3), (7, 3)]
        plan = _split_workload(10, 3)
        assert plan == expected_plan
        # No row may be lost: the shard sizes must add back up to the total.
        row_total = sum(size for _offset, size in plan)
        assert row_total == 10

    def test_split_workload_handles_exact_division(self):
        """An evenly divisible total yields equally sized shards."""
        expected_plan = [(0, 3), (3, 3), (6, 3), (9, 3)]
        plan = _split_workload(12, 4)
        assert plan == expected_plan


# ---------------------------------------------------------------------------
# Embedding tests
# ---------------------------------------------------------------------------
Expand Down
Loading