diff --git a/DGraph/distributed/commInfo.py b/DGraph/distributed/commInfo.py index 0d5bd7e..edf36a3 100644 --- a/DGraph/distributed/commInfo.py +++ b/DGraph/distributed/commInfo.py @@ -47,23 +47,22 @@ def compute_halo_vertices( """ Computes halo vertices. Supports both homogeneous and bipartite/heterogeneous relations. """ - # Fallback for homogeneous graphs if dst_partitioning is None: dst_partitioning = src_partitioning src_rank = src_partitioning[edge_list[:, 0]] dst_rank = dst_partitioning[edge_list[:, 1]] - # Cross-rank mask: source is local, destination is remote - cross_mask = (src_rank == rank) & (dst_rank != rank) + # FIX: Cross-rank mask for pull model: source is remote, destination is local + cross_mask = (src_rank != rank) & (dst_rank == rank) - # Return unique destination vertex IDs from those edges - return torch.unique(edge_list[cross_mask, 1]) + # FIX: Return unique SOURCE vertex IDs from those edges + return torch.unique(edge_list[cross_mask, 0]) def compute_local_edge_list( global_edge_list: torch.Tensor, # [E, 2] - partitioning: torch.Tensor, # [V] + partitioning: torch.Tensor, # [V] (Acts as dst_partitioning) local_vertices_global: torch.Tensor, # [num_local] halo_vertices_global: torch.Tensor, # [num_halo] rank: int, @@ -72,8 +71,8 @@ def compute_local_edge_list( num_halo = halo_vertices_global.size(0) num_global = partitioning.size(0) - # Filter edges owned by this rank - local_edge_mask = partitioning[global_edge_list[:, 0]] == rank + # FIX: Filter edges where the DESTINATION is owned by this rank (Index 1) + local_edge_mask = partitioning[global_edge_list[:, 1]] == rank local_edges_global = global_edge_list[local_edge_mask] # Build inverse map: global_id -> local_idx via scatter diff --git a/experiments/GraphCast/data_utils/graphcast_graph.py b/experiments/GraphCast/data_utils/graphcast_graph.py index 3d592e7..1137559 100644 --- a/experiments/GraphCast/data_utils/graphcast_graph.py +++ b/experiments/GraphCast/data_utils/graphcast_graph.py @@ 
-57,29 +57,29 @@ class GraphCastTopology: @dataclass class DistributedGraphCastGraph: + # Distributed environment info rank: int world_size: int ranks_per_graph: int + + # Graph metadata mesh_level: int lat_lon_grid: Tensor + + # Mesh vertex features mesh_graph_node_features: Tensor mesh_graph_edge_features: Tensor - mesh_graph_node_rank_placement: Tensor - mesh_graph_edge_rank_placement: Tensor - mesh_graph_src_indices: Tensor - mesh_graph_dst_indices: Tensor - mesh_graph_src_rank_placement: Tensor - mesh_graph_dst_rank_placement: Tensor - grid_rank_placement: Tensor + + # Grid vertex features mesh2grid_graph_node_features: Tensor - mesh2grid_graph_edge_features: Tensor - mesh2grid_graph_edge_rank_placement: Tensor - mesh2grid_graph_src_indices: Tensor - mesh2grid_graph_dst_indices: Tensor grid2mesh_graph_node_features: Tensor + + # Mesh <--> Grid edge features + mesh2grid_graph_edge_features: Tensor grid2mesh_graph_edge_features: Tensor - grid2mesh_graph_src_indices: Tensor - grid2mesh_graph_dst_indices: Tensor + + # Distributed graph info + distributed_comm_patterns: GraphCastCommPatterns def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatterns: @@ -110,7 +110,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt grid2mesh_cp = build_communication_pattern( global_edge_list=grid2mesh_edges, partitioning=mesh_part, - neighbor_partitioning=grid_part, rank=rank, world_size=world_size, ) @@ -129,7 +128,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt mesh_cp = build_communication_pattern( global_edge_list=mesh_edges, partitioning=mesh_part, - neighbor_partitioning=mesh_part, rank=rank, world_size=world_size, ) @@ -147,7 +145,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt mesh2grid_cp = build_communication_pattern( global_edge_list=mesh2grid_edges, partitioning=grid_part, - neighbor_partitioning=mesh_part, rank=rank, 
world_size=world_size, ) @@ -220,6 +217,47 @@ def get_mesh_graph_partition(mesh_level: int, world_size: int): mesh_vertex_rank_placement = torch.tensor(mesh_vertex_rank_placement) return mesh_vertex_rank_placement + @staticmethod + def get_grid_vertex_partition( + lat: int, + lon: int, + mesh_vertex_rank_placement: torch.Tensor, + grid2mesh_grid_src_indices: torch.Tensor, + grid2mesh_mesh_dst_indices: torch.Tensor, + mesh2grid_mesh_src_indices: torch.Tensor, + world_size: int, + ) -> torch.Tensor: + """Generate the partitioning of grid vertices to minimize cross-rank edges. + + For each grid vertex, counts how many of its connected mesh vertices + (via both grid2mesh and mesh2grid edges) live on each rank, then assigns + the grid vertex to the rank with the plurality of connections. + + mesh2grid grid destinations are implicit: grid vertex i owns edges + [3i, 3i+1, 3i+2] since create_mesh2grid_graph assigns exactly 3 + edges (one face's vertices) per grid vertex. + """ + num_grid = lat * lon + votes = torch.zeros(num_grid, world_size, dtype=torch.long) + + # --- grid2mesh contribution: grid vertex is src, mesh vertex is dst --- + g2m_ranks = mesh_vertex_rank_placement[grid2mesh_mesh_dst_indices.long()] + # Flatten (grid_vertex, rank) into a 1D index for scatter_add_ + g2m_flat_idx = grid2mesh_grid_src_indices.long() * world_size + g2m_ranks + votes.view(-1).scatter_add_(0, g2m_flat_idx, torch.ones_like(g2m_flat_idx)) + + # --- mesh2grid contribution: mesh vertex is src, grid vertex is dst --- + # Each grid vertex i has exactly 3 mesh2grid edges at positions [3i, 3i+1, 3i+2] + m2g_grid_dst = torch.arange(num_grid, dtype=torch.long).repeat_interleave(3) + m2g_ranks = mesh_vertex_rank_placement[mesh2grid_mesh_src_indices.long()] + m2g_flat_idx = m2g_grid_dst * world_size + m2g_ranks + votes.view(-1).scatter_add_(0, m2g_flat_idx, torch.ones_like(m2g_flat_idx)) + + # Assign each grid vertex to the rank with the most connections + grid_partitioning = votes.argmax(dim=1) 
+ + return grid_partitioning + def get_mesh_graph(self, mesh_vertex_rank_placement: torch.Tensor): """Get the graph for the distributed graphcast graph.""" @@ -327,8 +365,6 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict): grid_vertex_rank_placement ) - # TODO: Consider we can have it to so grid2mesh edges don't require a - # backpropagation. (If encoder is after the gather / scatter) grid2mesh_graph_dict = { "node_features": torch.tensor([]), "edge_features": edge_features, @@ -340,7 +376,7 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict): } return grid2mesh_graph_dict - def get_mesh2grid_graph( + def get_mesh2grid_edges( self, grid_vertex_rank_placement, renumbered_vertices, @@ -359,7 +395,6 @@ def get_mesh2grid_graph( mesh2grid_edge_rank_placement = grid_vertex_rank_placement[dst_grid_indices] mesh2grid_graph_dict = { - "node_features": torch.tensor([]), "edge_features": edge_features, "src_indices": src_mesh_indices, "dst_indices": dst_grid_indices, @@ -401,10 +436,26 @@ def get_graphcast_graph( grid_vertex_rank_placement = grid2mesh_graph["grid_vertex_rank_placement"] renumbered_grid = grid2mesh_graph["renumbered_grid"] - mesh2grid_graph = self.get_mesh2grid_graph( + mesh2grid_graph = self.get_mesh2grid_edges( grid_vertex_rank_placement, renumbered_vertices, renumbered_grid ) + topology = GraphCastTopology( + rank=self.local_rank, + world_size=self.world_size, + ranks_per_graph=self.ranks_per_graph, + mesh_rank_placement=mesh_vertex_rank_placement, + grid_rank_placement=grid_vertex_rank_placement, + mesh_graph_src_indices=mesh_graph["src_indices"], + mesh_graph_dst_indices=mesh_graph["dst_indices"], + mesh2grid_graph_src_indices=mesh2grid_graph["src_indices"], + mesh2grid_graph_dst_indices=mesh2grid_graph["dst_indices"], + grid2mesh_graph_src_indices=grid2mesh_graph["src_indices"], + grid2mesh_graph_dst_indices=grid2mesh_graph["dst_indices"], + ) + + comm_patterns = build_graphcast_comm_patterns(topology) + return DistributedGraphCastGraph( 
rank=self.rank, world_size=self.world_size, @@ -413,25 +464,9 @@ def get_graphcast_graph( lat_lon_grid=self.lat_lon_grid, mesh_graph_node_features=mesh_graph["node_features"], mesh_graph_edge_features=mesh_graph["edge_features"], - mesh_graph_node_rank_placement=mesh_graph["node_rank_placement"], - mesh_graph_edge_rank_placement=mesh_graph["edge_rank_placement"], - mesh_graph_src_indices=mesh_graph["src_indices"], - mesh_graph_dst_indices=mesh_graph["dst_indices"], - mesh_graph_src_rank_placement=mesh_graph["src_rank_placement"], - mesh_graph_dst_rank_placement=mesh_graph["dst_rank_placement"], - grid_rank_placement=grid2mesh_graph["grid_vertex_rank_placement"], - mesh2grid_graph_node_features=mesh2grid_graph["node_features"], - mesh2grid_graph_edge_features=mesh2grid_graph["edge_features"], - mesh2grid_graph_edge_rank_placement=mesh2grid_graph[ - "mesh2grid_edge_rank_placement" - ], - mesh2grid_graph_src_indices=mesh2grid_graph["src_indices"], - mesh2grid_graph_dst_indices=mesh2grid_graph["dst_indices"], + mesh2grid_graph_node_features=torch.tensor([]), grid2mesh_graph_node_features=grid2mesh_graph["node_features"], + mesh2grid_graph_edge_features=mesh2grid_graph["edge_features"], grid2mesh_graph_edge_features=grid2mesh_graph["edge_features"], - grid2mesh_graph_edge_rank_placement=grid2mesh_graph[ - "grid2mesh_edge_rank_placement" - ], - grid2mesh_graph_src_indices=grid2mesh_graph["src_indices"], - grid2mesh_graph_dst_indices=grid2mesh_graph["dst_indices"], + distributed_comm_patterns=comm_patterns, ) diff --git a/experiments/GraphCast/dataset.py b/experiments/GraphCast/dataset.py index 3e4eb08..b15e700 100644 --- a/experiments/GraphCast/dataset.py +++ b/experiments/GraphCast/dataset.py @@ -261,20 +261,14 @@ def test_synthetic_weather_dataset(num_days, batch_size=1): print("Mesh label:\t", static_graph.mesh_level) print("Mesh Node features:\t", static_graph.mesh_graph_node_features.shape) print("Mesh Edge features:\t", static_graph.mesh_graph_edge_features.shape) 
- print("Mesh src indices:\t", static_graph.mesh_graph_src_indices.shape) - print("Mesh dst indices:\t", static_graph.mesh_graph_dst_indices.shape) print("=" * 80) print( "mesh2grid edge features:\t", static_graph.mesh2grid_graph_edge_features.shape ) - print("mesh2grid src indices:\t", static_graph.mesh2grid_graph_src_indices.shape) - print("mesh2grid dst indices:\t", static_graph.mesh2grid_graph_dst_indices.shape) print("=" * 80) print( "grid2mesh edge features:\t", static_graph.grid2mesh_graph_edge_features.shape ) - print("grid2mesh src indices:\t", static_graph.grid2mesh_graph_src_indices.shape) - print("grid2mesh dst indices:\t", static_graph.grid2mesh_graph_dst_indices.shape) print("=" * 80) diff --git a/experiments/GraphCast/layers.py b/experiments/GraphCast/layers.py index 524987b..027974b 100644 --- a/experiments/GraphCast/layers.py +++ b/experiments/GraphCast/layers.py @@ -11,14 +11,12 @@ # https://github.com/LBANN and https://github.com/LLNL/LBANN. # # SPDX-License-Identifier: (Apache-2.0) - -from typing import Tuple, Union -import numpy as np import torch import torch.nn as nn -from typing import Optional -from DGraph.Communicator import Communicator -from dist_utils import SingleProcessDummyCommunicator +from DGraph.utils.TimingReport import TimingReport + +""" +Local only layers for mesh processing. 
These layers do not perform any communication and can be used in both GraphCast and MeshGraphNet.""" class MeshGraphMLP(nn.Module): @@ -72,7 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: The transformed tensor """ - return self._model(x) + with TimingReport("MeshGraphMLP/forward"): + return self._model(x) class MeshNodeBlock(nn.Module): @@ -83,7 +82,6 @@ def __init__( input_node_dim: int, input_edge_dim: int, output_node_dim: int, - comm: Union[Communicator, SingleProcessDummyCommunicator], hidden_dim: int = 512, num_hidden_layers: int = 1, aggregation_type: str = "sum", @@ -102,7 +100,6 @@ def __init__( super(MeshNodeBlock, self).__init__() assert aggregation_type in ["sum"], "Only sum aggregation is supported for now." self.aggregation_type = aggregation_type - self.comm = comm self.mesh_mlp = MeshGraphMLP( input_dim=input_node_dim + input_edge_dim, output_dim=output_node_dim, @@ -115,7 +112,6 @@ def forward( node_features: torch.Tensor, edge_features: torch.Tensor, src_indices: torch.Tensor, - rank_mapping: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the node block @@ -129,17 +125,23 @@ def forward( Returns: The updated node features """ - # Sum all the edge features for each node num_local_nodes = node_features.shape[0] - # TODO: This can be optimized by a fused gather-scatter operation - S.Z - - aggregated_edge_features = self.comm.scatter( - edge_features, src_indices, rank_mapping, num_local_nodes - ) - # Concatenate the node and edge features - x = torch.cat([node_features, aggregated_edge_features], dim=-1) - # Apply the MLP - node_features_new = self.mesh_mlp(x) + node_features + with TimingReport("MeshNodeBlock/scatter_add"): + aggregated_edge_features = torch.zeros( + num_local_nodes, + edge_features.shape[-1], + device=edge_features.device, + dtype=edge_features.dtype, + ) + aggregated_edge_features.scatter_add_( + 0, + src_indices.unsqueeze(-1).expand(-1, edge_features.shape[-1]), + edge_features, + ) + + with 
TimingReport("MeshNodeBlock/mlp"): + x = torch.cat([node_features, aggregated_edge_features], dim=-1) + node_features_new = self.mesh_mlp(x) + node_features return node_features_new @@ -152,7 +154,6 @@ def __init__( input_dst_node_dim: int, input_edge_dim: int, output_edge_dim: int, - comm: Union[Communicator, SingleProcessDummyCommunicator], hidden_dim: int = 512, num_hidden_layers: int = 1, aggregation_type: str = "sum", @@ -162,7 +163,6 @@ def __init__( input_node_dim (int): The dimensionality of the input node features. input_edge_dim (int): The dimensionality of the input edge features. output_edge_dim (int): The dimensionality of the output edge features. - comm (CommunicatorBase): The communicator to use for distributed training. hidden_dim (int, optional): The dimensionality of the hidden layers. Defaults to 512. aggregation_type (str, optional): The type of aggregation to use. Defaults to "sum". """ @@ -171,7 +171,6 @@ def __init__( super(MeshEdgeBlock, self).__init__() assert aggregation_type in ["sum"], "Only sum aggregation is supported for now." 
self.aggregation_type = aggregation_type - self.comm = comm self.mesh_mlp = MeshGraphMLP( input_dim=input_src_node_dim + input_dst_node_dim + input_edge_dim, output_dim=output_edge_dim, @@ -186,8 +185,6 @@ def forward( edge_features: torch.Tensor, src_indices: torch.Tensor, dst_indices: torch.Tensor, - src_rank_mapping: Optional[torch.Tensor] = None, - dst_rank_mapping: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute the edge block @@ -201,16 +198,13 @@ def forward( Returns: The updated edge features """ - # Concatenate the source and destination node features with the edge features - src_node_features = self.comm.gather( - src_node_features, src_indices, src_rank_mapping - ) - dst_node_features = self.comm.gather( - dst_node_features, dst_indices, dst_rank_mapping - ) - concatenated_features = torch.cat( - [src_node_features, dst_node_features, edge_features], dim=-1 - ) - # Apply the MLP - edge_features_new = self.mesh_mlp(concatenated_features) + edge_features + with TimingReport("MeshEdgeBlock/gather"): + src_node_features = src_node_features[src_indices] + dst_node_features = dst_node_features[dst_indices] + + with TimingReport("MeshEdgeBlock/mlp"): + concatenated_features = torch.cat( + [src_node_features, dst_node_features, edge_features], dim=-1 + ) + edge_features_new = self.mesh_mlp(concatenated_features) + edge_features return edge_features_new diff --git a/experiments/GraphCast/model.py b/experiments/GraphCast/model.py index da6459b..9183ad8 100644 --- a/experiments/GraphCast/model.py +++ b/experiments/GraphCast/model.py @@ -14,11 +14,14 @@ import torch import torch.nn as nn -from typing import Optional, Tuple +from typing import Tuple from torch import Tensor from layers import MeshEdgeBlock, MeshGraphMLP, MeshNodeBlock from graphcast_config import Config from data_utils.graphcast_graph import DistributedGraphCastGraph +from DGraph.distributed import HaloExchange +from DGraph.distributed.commInfo import CommunicationPattern +from 
DGraph.utils.TimingReport import TimingReport class GraphCastEmbedder(nn.Module): @@ -116,55 +119,62 @@ def __init__(self, cfg: Config, comm, *args, **kwargs) -> None: comm: Communicator object """ super().__init__(*args, **kwargs) + hidden_dim = cfg.model.hidden_dim - edge_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, - ) - self.edge_mlp = MeshEdgeBlock(*edge_block_invars) - - node_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + self.exchanger = HaloExchange(comm) + + self.edge_mlp = MeshEdgeBlock( + input_src_node_dim=hidden_dim, + input_dst_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim=hidden_dim, ) - self.mesh_node_mlp = MeshNodeBlock(*node_block_invars) - self.grid_node_mlp = MeshGraphMLP( - input_dim=cfg.model.hidden_dim, output_dim=cfg.model.hidden_dim + self.mesh_node_mlp = MeshNodeBlock( + input_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_node_dim=hidden_dim, + hidden_dim=hidden_dim, ) + self.grid_node_mlp = MeshGraphMLP(input_dim=hidden_dim, output_dim=hidden_dim) def forward( self, grid_node_features: Tensor, mesh_node_features: Tensor, grid2mesh_edge_features: Tensor, - grid2mesh_edge_indices_src: Tensor, - grid2mesh_edge_indices_dst: Tensor, + comm_pattern: CommunicationPattern, ) -> Tuple[Tensor, Tensor]: + # local_edge_list: [E, 2] with [central=mesh, neighbor=grid/halo] + edge_index = comm_pattern.local_edge_list + dst_indices = edge_index[:, 0] # mesh (central, aggregation target) + src_indices = edge_index[:, 1] # grid/halo (neighbor, message source) + num_local = comm_pattern.num_local_vertices + + with TimingReport("encoder/halo_exchange"): + halo_features = self.exchanger(mesh_node_features, comm_pattern) + augmented = torch.cat([mesh_node_features, halo_features], dim=0) + + with 
TimingReport("encoder/edge_block"): + e_feats = self.edge_mlp( + src_node_features=augmented, + dst_node_features=augmented, + edge_features=grid2mesh_edge_features, + src_indices=src_indices, + dst_indices=dst_indices, + ) - e_feats = self.edge_mlp( - src_node_features=grid_node_features, - dst_node_features=mesh_node_features, - edge_features=grid2mesh_edge_features, - src_indices=grid2mesh_edge_indices_src, - dst_indices=grid2mesh_edge_indices_dst, - ) + with TimingReport("encoder/node_block"): + n_feats = self.mesh_node_mlp( + node_features=augmented[:num_local], + edge_features=e_feats, + src_indices=dst_indices, + ) - n_feats = self.mesh_node_mlp( - node_features=mesh_node_features, - edge_features=e_feats, - src_indices=grid2mesh_edge_indices_dst, - ) + with TimingReport("encoder/grid_mlp"): + grid_node_features = grid_node_features + self.grid_node_mlp(grid_node_features) mesh_node_features = mesh_node_features + n_feats - grid_node_features = grid_node_features + self.grid_node_mlp(grid_node_features) - return grid_node_features, mesh_node_features @@ -179,54 +189,73 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): comm: Communicator object """ super().__init__() + hidden_dim = cfg.model.hidden_dim processor_layers = cfg.model.processor_layers - node_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + + self.exchanger = HaloExchange(comm) + + self.edge_processors = nn.ModuleList( + [ + MeshEdgeBlock( + input_src_node_dim=hidden_dim, + input_dst_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim=hidden_dim, + ) + for _ in range(processor_layers) + ] ) - edge_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + self.node_processors = nn.ModuleList( + [ + MeshNodeBlock( + input_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + 
output_node_dim=hidden_dim, + hidden_dim=hidden_dim, + ) + for _ in range(processor_layers) + ] ) - edge_layers = [] - node_layers = [] - for _ in range(processor_layers): - edge_layers.append(MeshEdgeBlock(*edge_block_invars)) - for _ in range(processor_layers): - node_layers.append(MeshNodeBlock(*node_block_invars)) - - self.edge_processors = nn.ModuleList(edge_layers) - self.node_processors = nn.ModuleList(node_layers) def forward( self, embedded_mesh_features: Tensor, embedded_mesh2mesh_edge_features: Tensor, - mesh2mesh_edge_indices_src: Tensor, - mesh2mesh_edge_indices_dst: Tensor, + comm_pattern: CommunicationPattern, ) -> Tuple[Tensor, Tensor]: e_feats = embedded_mesh2mesh_edge_features n_feats = embedded_mesh_features - for edge_layer, node_layer in zip(self.edge_processors, self.node_processors): - e_feats = edge_layer( - n_feats, - n_feats, - e_feats, - mesh2mesh_edge_indices_src, - mesh2mesh_edge_indices_dst, - ) - n_feats = node_layer( - n_feats, - e_feats, - mesh2mesh_edge_indices_src, - ) + + # local_edge_list: [E, 2] with [central=mesh_dst, neighbor=mesh_src] + edge_index = comm_pattern.local_edge_list + dst_indices = edge_index[:, 0] # central (aggregation target) + src_indices = edge_index[:, 1] # neighbor (message source) + num_local = comm_pattern.num_local_vertices + + for i, (edge_layer, node_layer) in enumerate( + zip(self.edge_processors, self.node_processors) + ): + with TimingReport(f"processor/layer_{i}/halo_exchange"): + halo_features = self.exchanger(n_feats, comm_pattern) + augmented = torch.cat([n_feats, halo_features], dim=0) + + with TimingReport(f"processor/layer_{i}/edge_block"): + e_feats = edge_layer( + src_node_features=augmented, + dst_node_features=augmented, + edge_features=e_feats, + src_indices=src_indices, + dst_indices=dst_indices, + ) + + with TimingReport(f"processor/layer_{i}/node_block"): + n_feats = node_layer( + node_features=augmented[:num_local], + edge_features=e_feats, + src_indices=dst_indices, + ) + return 
n_feats, e_feats @@ -243,26 +272,22 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): comm: Communicator object """ super().__init__() - edge_block_invars = ( - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - cfg.model.hidden_dim, - comm, - cfg.model.hidden_dim, + hidden_dim = cfg.model.hidden_dim + + self.exchanger = HaloExchange(comm) + + self.edge_mlp = MeshEdgeBlock( + input_src_node_dim=hidden_dim, + input_dst_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_edge_dim=hidden_dim, + hidden_dim=hidden_dim, ) - self.comm = comm - self.edge_mlp = MeshEdgeBlock(*edge_block_invars) - dst_node_input_dim = cfg.model.hidden_dim - dst_node_output_dim = cfg.model.hidden_dim - m2g_edge_output_dim = cfg.model.hidden_dim self.node_mlp = MeshNodeBlock( - input_node_dim=dst_node_input_dim, - input_edge_dim=m2g_edge_output_dim, - output_node_dim=dst_node_output_dim, - hidden_dim=cfg.model.hidden_dim, - comm=comm, - num_hidden_layers=1, + input_node_dim=hidden_dim, + input_edge_dim=hidden_dim, + output_node_dim=hidden_dim, + hidden_dim=hidden_dim, ) def forward( @@ -270,42 +295,47 @@ def forward( mesh2grid_edge_features: Tensor, grid_node_features: Tensor, mesh_node_features: Tensor, - mesh2grid_edge_indices_src: Tensor, - mesh2grid_edge_indices_dst: Tensor, + comm_pattern: CommunicationPattern, ) -> Tensor: """ Args: mesh2grid_edge_features (Tensor): The edge features from the mesh to the grid grid_node_features (Tensor): The grid node features mesh_node_features (Tensor): The mesh node features - mesh2grid_edge_indices_src (Tensor): The source indices for the mesh2grid - bipartitate edges. These are the indices - of the mesh nodes that are connected to - the grid nodes. - mesh2grid_edge_indices_dst (Tensor): The destination indices for the mesh2grid - bipartitate edges. These are the indices of - the grid nodes that are connected to the - mesh nodes. 
+ comm_pattern (CommunicationPattern): Precomputed communication pattern + for the mesh2grid bipartite graph (partitioned by grid vertex placement). Returns: (Tensor): The updated grid node features """ - e_feats = self.edge_mlp( - src_node_features=mesh_node_features, - dst_node_features=grid_node_features, - edge_features=mesh2grid_edge_features, - src_indices=mesh2grid_edge_indices_src, - dst_indices=mesh2grid_edge_indices_dst, - ) - n_feats = self.node_mlp( - node_features=grid_node_features, - edge_features=e_feats, - src_indices=mesh2grid_edge_indices_dst, - ) + # local_edge_list: [E, 2] with [central=grid, neighbor=mesh/halo] + edge_index = comm_pattern.local_edge_list + dst_indices = edge_index[:, 0] # grid (central, aggregation target) + src_indices = edge_index[:, 1] # mesh/halo (neighbor, message source) + num_local = comm_pattern.num_local_vertices + + with TimingReport("decoder/halo_exchange"): + # Mesh nodes are the neighbors (sources); grid nodes are the central (destination). 
+ halo_mesh_features = self.exchanger(mesh_node_features, comm_pattern) + augmented_mesh = torch.cat([mesh_node_features, halo_mesh_features], dim=0) + + with TimingReport("decoder/edge_block"): + e_feats = self.edge_mlp( + src_node_features=augmented_mesh, # mesh features (local + halo) + dst_node_features=grid_node_features, # grid features (destination side) + edge_features=mesh2grid_edge_features, + src_indices=src_indices, + dst_indices=dst_indices, + ) - n_feats = grid_node_features + n_feats + with TimingReport("decoder/node_block"): + n_feats = self.node_mlp( + node_features=grid_node_features[:num_local], # local grid nodes being updated + edge_features=e_feats, + src_indices=dst_indices, + ) - return n_feats + return grid_node_features + n_feats class DGraphCast(nn.Module): @@ -320,8 +350,7 @@ def __init__(self, cfg: Config, comm, *args, **kwargs): super().__init__() self.hidden_dim = cfg.model.hidden_dim self.output_grid_dim = cfg.model.output_grid_dim - self.comm = comm - self.embedder = GraphCastEmbedder(cfg=cfg, comm=comm, *args, **kwargs) + self.embedder = GraphCastEmbedder(cfg=cfg, *args, **kwargs) self.encoder = GraphCastEncoder(cfg=cfg, comm=comm, *args, **kwargs) self.processor = GraphCastProcessor(cfg=cfg, comm=comm, *args, **kwargs) self.decoder = GraphCastDecoder(cfg=cfg, comm=comm, *args, **kwargs) @@ -340,26 +369,21 @@ def forward( Returns: (Tensor): The predicted output grid """ - input_grid_features = input_grid_features.squeeze(0) input_mesh_features = static_graph.mesh_graph_node_features mesh2mesh_edge_features = static_graph.mesh_graph_edge_features grid2mesh_edge_features = static_graph.grid2mesh_graph_edge_features mesh2grid_edge_features = static_graph.mesh2grid_graph_edge_features - mesh2mesh_edge_indices_src = static_graph.mesh_graph_src_indices - mesh2mesh_edge_indices_dst = static_graph.mesh_graph_dst_indices - mesh2grid_edge_indices_src = static_graph.mesh2grid_graph_src_indices - mesh2grid_edge_indices_dst = 
static_graph.mesh2grid_graph_dst_indices - grid2mesh_edge_indices_src = static_graph.grid2mesh_graph_src_indices - grid2mesh_edge_indices_dst = static_graph.grid2mesh_graph_dst_indices - - out = self.embedder( - input_grid_features, - input_mesh_features, - mesh2mesh_edge_features, - grid2mesh_edge_features, - mesh2grid_edge_features, - ) + comm_patterns = static_graph.distributed_comm_patterns + + with TimingReport("model/embed"): + out = self.embedder( + input_grid_features, + input_mesh_features, + mesh2mesh_edge_features, + grid2mesh_edge_features, + mesh2grid_edge_features, + ) ( embedded_grid_features, embedded_mesh_features, @@ -367,28 +391,31 @@ def forward( embedded_grid2mesh_edge_features, embedded_mesh2grid_edge_features, ) = out - encoded_grid_features, encoded_mesh_features = self.encoder( - embedded_grid_features, - embedded_mesh_features, - embedded_grid2mesh_edge_features, - grid2mesh_edge_indices_src, - grid2mesh_edge_indices_dst, - ) - out = self.processor( - encoded_mesh_features, - embedded_mesh2mesh_edge_features, - mesh2mesh_edge_indices_src, - mesh2mesh_edge_indices_dst, - ) - processed_mesh_node_features, _ = out - x = self.decoder( - embedded_mesh2grid_edge_features, - encoded_grid_features, - processed_mesh_node_features, - mesh2grid_edge_indices_src, - mesh2grid_edge_indices_dst, - ) - output = self.final_prediction(x) + with TimingReport("model/encode"): + encoded_grid_features, encoded_mesh_features = self.encoder( + embedded_grid_features, + embedded_mesh_features, + embedded_grid2mesh_edge_features, + comm_patterns.grid2mesh, + ) + + with TimingReport("model/process"): + processed_mesh_node_features, _ = self.processor( + encoded_mesh_features, + embedded_mesh2mesh_edge_features, + comm_patterns.mesh, + ) + + with TimingReport("model/decode"): + x = self.decoder( + embedded_mesh2grid_edge_features, + encoded_grid_features, + processed_mesh_node_features, + comm_patterns.mesh2grid, + ) + + with TimingReport("model/final_prediction"): + 
output = self.final_prediction(x) output = input_grid_features + output return output diff --git a/experiments/cost_model_benchmarks/analysis/__init__.py b/experiments/cost_model_benchmarks/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/cost_model_benchmarks/analysis/compute_predictions.py b/experiments/cost_model_benchmarks/analysis/compute_predictions.py new file mode 100644 index 0000000..3e69f4d --- /dev/null +++ b/experiments/cost_model_benchmarks/analysis/compute_predictions.py @@ -0,0 +1,160 @@ +"""Analysis — Apply Assembled Cost Model and Compare to Measurements. + +Reads: + - data/fitted_primitives.json + - data/fitted_overhead.json + - All end-to-end JSON run files + +For every run, computes the predicted T_layer: + + T_layer = T_comp + max(T_intra, T_inter) + T_buffer_copy + T_overhead + +and records the measured median, predicted value, absolute error, and relative +error (as a fraction). + +Also computes aggregate MAPE for the fit subset and the held-out subset +(using the same --fit-filter expression as fit_overhead.py). + +Outputs ``data/predictions.json``. 
def parse_args():
    """CLI arguments for the prediction/comparison step."""
    p = argparse.ArgumentParser(description="Apply cost model and compute predictions")
    p.add_argument("--primitives", type=str, required=True)
    p.add_argument("--overhead", type=str, required=True)
    p.add_argument("--e2e-runs", nargs="+", required=True, metavar="FILE")
    p.add_argument("--fit-filter", type=str, default="world_size <= 8",
                   help="Same expression used when fitting overhead (determines train/test split)")
    p.add_argument("--output", type=str, default="data/predictions.json")
    return p.parse_args()


def _network_time(net: dict, nbytes: int, mode: str) -> float:
    """T = t_L + bytes / B for one link type; 0.0 when unfitted or no traffic."""
    params = net.get(mode, None)
    if params is None or nbytes == 0:
        return 0.0
    return params.get("latency_seconds", 0.0) + nbytes / params.get("bandwidth_bytes_per_sec", 1e10)


def _buffer_copy_time(primitives: dict, send_bytes: int) -> float:
    """Gather-kernel time to pack the send buffer (clustered distribution)."""
    gath_params = primitives.get("gather", {}).get("clustered", {}).get("gather", None)
    if not gath_params or send_bytes <= 0:
        return 0.0
    B_g = gath_params.get("bandwidth_bytes_per_sec", 1e12)
    # FIX: fit_primitives emits "launch_overhead_seconds"; the previously-read
    # "intercept_seconds" key is never produced, so the fitted gather overhead
    # was silently treated as 0. Keep the old key as a fallback.
    t0 = gath_params.get("launch_overhead_seconds",
                         gath_params.get("intercept_seconds", 0.0))
    return t0 + send_bytes / B_g


def main():
    """Apply the assembled cost model to every run and write predictions.json.

    For each run: predicted T_layer = T_model + T_overhead, plus a breakdown
    (T_comp / T_comm / T_buffer_copy / T_overhead) for ablation figures, and
    aggregate MAPE over the fit and held-out subsets.
    """
    args = parse_args()

    with open(args.primitives) as f:
        primitives = json.load(f)
    with open(args.overhead) as f:
        overhead_data = json.load(f)

    T_overhead = overhead_data.get("overhead_seconds", 0.0)

    all_runs = load_e2e_runs(args.e2e_runs)
    fit_runs, held_runs = apply_filter(all_runs, args.fit_filter)

    # Runs are plain dicts (unhashable), so fit-set membership is tracked by id().
    fit_set = set(id(r) for r in fit_runs)

    net = primitives.get("network", {})

    prediction_entries = []
    for r in all_runs:
        T_model_base = predict_layer_time(r["config"], r["per_rank_stats"], primitives)
        T_pred = T_model_base + T_overhead
        T_meas = r["measured_median"]
        abs_err = abs(T_meas - T_pred)
        rel_err = abs_err / T_meas if T_meas > 0 else float("nan")

        # Decompose the prediction for ablation figures (mirrors the terms
        # inside predict_layer_time via the shared helpers above).
        stats = r["per_rank_stats"]
        T_intra = _network_time(net, stats.get("c_intra_bytes", 0), "intra")
        T_inter = _network_time(net, stats.get("c_inter_bytes", 0), "inter")
        T_comm = max(T_intra, T_inter)

        F = r["config"]["feature_dim"]
        send_bytes = stats.get("send_total", 0) * F * 4  # float32 features
        T_buffer_copy = _buffer_copy_time(primitives, send_bytes)

        prediction_entries.append({
            "source_file": r["source_file"],
            "config": r["config"],
            "partition_stats": stats,
            "measured_median_seconds": T_meas,
            "predicted_seconds": T_pred,
            "absolute_error_seconds": abs_err,
            "relative_error": rel_err,
            "in_fit_set": id(r) in fit_set,
            "breakdown": {
                # Remainder of the model after removing comm + buffer terms.
                "T_comp_seconds": T_model_base - T_comm - T_buffer_copy,
                "T_comm_seconds": T_comm,
                "T_buffer_copy_seconds": T_buffer_copy,
                "T_overhead_seconds": T_overhead,
            },
        })

    def mape(entries):
        """Mean relative error, skipping NaN entries (non-positive medians)."""
        errs = [e["relative_error"] for e in entries
                if not np.isnan(e["relative_error"])]
        return float(np.mean(errs)) if errs else float("nan")

    fit_entries = [e for e in prediction_entries if e["in_fit_set"]]
    held_entries = [e for e in prediction_entries if not e["in_fit_set"]]

    mape_fit = mape(fit_entries)
    mape_held = mape(held_entries)
    mape_total = mape(prediction_entries)

    print(f"[predictions] Fit MAPE={mape_fit*100:.2f}% "
          f"Held-out MAPE={mape_held*100:.2f}% "
          f"Total MAPE={mape_total*100:.2f}%")

    result = {
        "fit_filter": args.fit_filter,
        "T_overhead_seconds": T_overhead,
        "aggregate": {
            "mape_fit_set": mape_fit,
            "mape_held_out": mape_held,
            "mape_all": mape_total,
            "num_fit": len(fit_entries),
            "num_held_out": len(held_entries),
        },
        "predictions": prediction_entries,
    }

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2)
    print(f"[predictions] Written to {out_path}")
"mape_all": mape_total, + "num_fit": len(fit_entries), + "num_held_out": len(held_entries), + }, + "predictions": prediction_entries, + } + + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w") as f: + json.dump(result, f, indent=2) + print(f"[predictions] Written to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/analysis/fit_overhead.py b/experiments/cost_model_benchmarks/analysis/fit_overhead.py new file mode 100644 index 0000000..426f440 --- /dev/null +++ b/experiments/cost_model_benchmarks/analysis/fit_overhead.py @@ -0,0 +1,226 @@ +"""Analysis — Fit Library Overhead Bias (T_overhead). + +Reads ``fitted_primitives.json`` and the small-K subset of end-to-end runs. +For each run it computes the model-predicted T_layer (without overhead), then +fits a single scalar T_overhead that minimises MAPE: + + MAPE = mean( |T_measured - (T_model + T_overhead)| / T_measured ) + +The subset used for fitting is controlled by ``--fit-filter``, which is +evaluated as a Python expression where each run's config fields are available +as local variables (e.g. ``"world_size <= 8"``). + +Outputs ``data/fitted_overhead.json``. + +Usage:: + + python -m analysis.fit_overhead \\ + --primitives data/fitted_primitives.json \\ + --e2e-runs data/e2e_*.json \\ + --fit-filter "world_size <= 8" \\ + --output data/fitted_overhead.json +""" + +import argparse +import json +from pathlib import Path + +import numpy as np + + +# --------------------------------------------------------------------------- +# Cost model (without overhead) +# --------------------------------------------------------------------------- + +def predict_layer_time(run_config: dict, per_rank_stats: dict, + primitives: dict) -> float: + """Predict T_layer for one rank using the assembled primitive model. 
+ + T_layer = T_comp + max(T_intra, T_inter) + T_buffer_copy + + Parameters + ---------- + run_config : dict + Config block from the end-to-end JSON (feature_dim, model, etc.) + per_rank_stats : dict + Stats for rank 0 from per_rank_stats list. + primitives : dict + Loaded fitted_primitives.json. + """ + F = run_config["feature_dim"] + model_type = run_config.get("model", "gcn") + + n_local = per_rank_stats.get("n_local", 0) + n_halo = per_rank_stats.get("n_halo", 0) + n_total = n_local + n_halo + + # A rough edge count estimate: use avg_degree * n_local as a proxy + avg_degree = run_config.get("avg_degree", 20.0) + n_edges_local = int(n_local * avg_degree) + + # T_comp + comp_params = primitives.get("compute", {}).get(model_type, {}).get("forward", None) + if comp_params: + T_comp = (comp_params["coeff_V"] * n_total + + comp_params["coeff_E"] * n_edges_local + + comp_params["intercept"]) + T_comp = max(T_comp, 0.0) + else: + T_comp = 0.0 + + # T_intra and T_inter + intra_bytes = per_rank_stats.get("c_intra_bytes", 0) + inter_bytes = per_rank_stats.get("c_inter_bytes", 0) + + net = primitives.get("network", {}) + + def net_time(nbytes: int, mode: str) -> float: + params = net.get(mode, None) + if params is None or nbytes == 0: + return 0.0 + B = params.get("bandwidth_bytes_per_sec", 1e10) + t_L = params.get("latency_seconds", 0.0) + return t_L + nbytes / B + + T_intra = net_time(intra_bytes, "intra") + T_inter = net_time(inter_bytes, "inter") + T_comm = max(T_intra, T_inter) + + # T_buffer_copy (gather of send buffer) + send_bytes = per_rank_stats.get("send_total", 0) * F * 4 + gath_params = primitives.get("gather", {}).get("clustered", {}).get("gather", None) + if gath_params and send_bytes > 0: + B_g = gath_params.get("bandwidth_bytes_per_sec", 1e12) + T_buffer_copy = gath_params.get("intercept_seconds", 0.0) + send_bytes / B_g + else: + T_buffer_copy = 0.0 + + return T_comp + T_comm + T_buffer_copy + + +# 
def load_e2e_runs(paths: list) -> list:
    """Load end-to-end run JSON files into flat per-measurement records.

    Each record carries the file's config, rank-0 partition stats, and the
    median of the rank-0 trial times. Measurements with no trials are skipped.
    """
    runs = []
    for p in paths:
        with open(p) as f:
            data = json.load(f)
        config = data.get("config", {})
        for meas in data.get("measurements", []):
            per_rank_stats = meas.get("per_rank_stats", [{}])
            rank0_stats = per_rank_stats[0] if per_rank_stats else {}
            trials = meas.get("rank0_trials_seconds", [])
            if not trials:
                continue
            runs.append({
                "config": config,
                "per_rank_stats": rank0_stats,
                "measured_median": float(np.median(trials)),
                "source_file": str(p),
            })
    return runs


def apply_filter(runs: list, filter_expr: str) -> tuple:
    """Split runs into (fit, held_out) sets using *filter_expr*.

    The expression sees each run's config fields and per-rank stats as local
    variables (e.g. "world_size <= 8"). An empty expression puts every run in
    the fit set. Runs whose evaluation fails are kept in the fit set.
    """
    if not filter_expr:
        return runs, []
    fit_runs, held_runs = [], []
    for r in runs:
        env = dict(r["config"])
        env.update(r["per_rank_stats"])
        try:
            # SECURITY NOTE: eval() on a user-supplied expression. Builtins
            # are stripped, but this is only safe for trusted CLI input.
            if eval(filter_expr, {"__builtins__": {}}, env):
                fit_runs.append(r)
            else:
                held_runs.append(r)
        except Exception as e:
            print(f"[fit_overhead] Warning: filter eval failed for run ({e}), including in fit set")
            fit_runs.append(r)
    return fit_runs, held_runs


# ---------------------------------------------------------------------------
# Scalar overhead fitting
# ---------------------------------------------------------------------------

def fit_overhead_scalar(fit_runs: list, primitives: dict) -> tuple:
    """Fit T_overhead to minimise MAPE on fit_runs. Returns (overhead, mape_in_sample)."""
    if not fit_runs:
        return 0.0, float("nan")

    # Predict once per run and reuse below (previously predict_layer_time was
    # evaluated twice per run).
    preds = [predict_layer_time(r["config"], r["per_rank_stats"], primitives)
             for r in fit_runs]
    residuals = [r["measured_median"] - T for r, T in zip(fit_runs, preds)]

    # Optimal scalar overhead that minimises sum of |err - overhead| / T_meas
    # is the weighted median; for uniform weights it's just the median of residuals.
    overhead = float(np.median(residuals))

    # FIX: guard against non-positive measured medians (previously a
    # ZeroDivisionError); consistent with compute_predictions' NaN handling.
    rel_errs = [abs(r["measured_median"] - (T + overhead)) / r["measured_median"]
                for r, T in zip(fit_runs, preds)
                if r["measured_median"] > 0]
    mape = float(np.mean(rel_errs)) if rel_errs else float("nan")
    return overhead, mape


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def parse_args():
    p = argparse.ArgumentParser(description="Fit T_overhead from end-to-end runs")
    p.add_argument("--primitives", type=str, required=True,
                   help="Path to fitted_primitives.json")
    p.add_argument("--e2e-runs", nargs="+", required=True, metavar="FILE")
    p.add_argument("--fit-filter", type=str, default="world_size <= 8",
                   help="Python expression evaluated per run; True → fit set")
    p.add_argument("--output", type=str, default="data/fitted_overhead.json")
    return p.parse_args()


def main():
    """CLI entry point: fit the scalar overhead and write fitted_overhead.json."""
    args = parse_args()

    with open(args.primitives) as f:
        primitives = json.load(f)

    all_runs = load_e2e_runs(args.e2e_runs)
    print(f"[fit_overhead] Loaded {len(all_runs)} run(s)")

    fit_runs, held_runs = apply_filter(all_runs, args.fit_filter)
    print(f"[fit_overhead] Fit set: {len(fit_runs)} Held-out: {len(held_runs)}")

    overhead, mape_in = fit_overhead_scalar(fit_runs, primitives)
    print(f"[fit_overhead] T_overhead = {overhead*1e3:.3f} ms in-sample MAPE = {mape_in*100:.2f}%")

    result = {
        "overhead_seconds": overhead,
        "fit_filter": args.fit_filter,
        "num_fit_points": len(fit_runs),
        "num_held_out": len(held_runs),
        "in_sample_mape": mape_in,
        "fit_subset_runs": [
            {
                "source_file": r["source_file"],
                "world_size": r["config"].get("world_size"),
                "feature_dim": r["config"].get("feature_dim"),
                "measured_median_seconds": r["measured_median"],
            }
            for r in fit_runs
        ],
    }

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(result, f, indent=2)
    print(f"[fit_overhead] Written to {out_path}")
def load_json_files(paths: list) -> list:
    """Parse every JSON file in *paths*, returning the objects in order."""
    def _read(path):
        with open(path) as fh:
            return json.load(fh)
    return [_read(p) for p in paths]


def median_of_trials(trials: list) -> float:
    """Median of a list of per-trial times, as a plain Python float."""
    return float(np.median(trials))


def r_squared(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination; 1.0 when the target has zero variance."""
    residual = y_true - y_pred
    centered = y_true - np.mean(y_true)
    ss_res = np.sum(residual ** 2)
    ss_tot = np.sum(centered ** 2)
    if ss_tot <= 0:
        return 1.0
    return float(1.0 - ss_res / ss_tot)


def linear_fit(x: np.ndarray, y: np.ndarray):
    """Fit y = slope * x + intercept via scipy linregress. Returns dict."""
    reg = sp_stats.linregress(x, y)
    fitted = reg.slope * x + reg.intercept
    return {
        "slope": float(reg.slope),
        "intercept": float(reg.intercept),
        "r_squared": r_squared(y, fitted),
    }


# ---------------------------------------------------------------------------
# Network fit: T = t_L + bytes / B
# ---------------------------------------------------------------------------


def fit_network(records: list) -> dict:
    """Fit (t_L, B) from ping-pong records (one mode per call).

    Regresses median time on message size: slope = 1/B, intercept = t_L.
    """
    sizes, medians = [], []
    for rec in records:
        for meas in rec["measurements"]:
            sizes.append(meas["params"]["message_bytes"])
            medians.append(median_of_trials(meas["trials_seconds"]))

    x = np.asarray(sizes, dtype=float)
    y = np.asarray(medians, dtype=float)

    fit = linear_fit(x, y)
    slope = fit["slope"]
    return {
        "bandwidth_bytes_per_sec": (1.0 / slope) if slope > 0 else float("nan"),
        "latency_seconds": fit["intercept"],
        "r_squared": fit["r_squared"],
        "_raw_slope": slope,
        "_raw_intercept": fit["intercept"],
        "_num_points": len(x),
    }


# ---------------------------------------------------------------------------
# Compute fit: T = coeff_V * |V| + coeff_E * |E| + intercept
# ---------------------------------------------------------------------------


def fit_compute(records: list, timing_key: str = "forward_trials_seconds") -> dict:
    """Fit compute cost as a function of |V| and |E|.

    Uses multiple linear regression: T = a * |V| + b * |E| + c
    """
    verts, edges, medians = [], [], []
    for rec in records:
        for meas in rec["measurements"]:
            verts.append(meas["params"]["num_vertices"])
            edges.append(meas["params"]["num_edges"])
            medians.append(median_of_trials(meas[timing_key]))

    V = np.asarray(verts, dtype=float)
    E = np.asarray(edges, dtype=float)
    T = np.asarray(medians, dtype=float)

    # Least squares against the design matrix [V, E, 1].
    design = np.column_stack([V, E, np.ones_like(V)])
    coeffs, _, _, _ = np.linalg.lstsq(design, T, rcond=None)
    predicted = design @ coeffs
    a, b, c = coeffs

    return {
        "coeff_V": float(a),
        "coeff_E": float(b),
        "intercept": float(c),
        "r_squared": r_squared(T, predicted),
        "_num_points": len(T),
    }
+ + Uses multiple linear regression: T = a * |V| + b * |E| + c + """ + V_arr, E_arr, T_arr = [], [], [] + for rec in records: + for meas in rec["measurements"]: + V_arr.append(meas["params"]["num_vertices"]) + E_arr.append(meas["params"]["num_edges"]) + T_arr.append(median_of_trials(meas[timing_key])) + + V_arr = np.array(V_arr, dtype=float) + E_arr = np.array(E_arr, dtype=float) + T_arr = np.array(T_arr, dtype=float) + + # Design matrix: [V, E, 1] + A = np.column_stack([V_arr, E_arr, np.ones_like(V_arr)]) + result, _, _, _ = np.linalg.lstsq(A, T_arr, rcond=None) + coeff_V, coeff_E, intercept = result + T_pred = A @ result + r2 = r_squared(T_arr, T_pred) + + return { + "coeff_V": float(coeff_V), + "coeff_E": float(coeff_E), + "intercept": float(intercept), + "r_squared": r2, + "_num_points": len(T_arr), + } + + +# --------------------------------------------------------------------------- +# Gather fit: T = intercept + max(overhead, bytes / B_gather) +# --------------------------------------------------------------------------- + + +def fit_gather(records: list, timing_key: str = "gather_trials_seconds") -> dict: + k_arr, T_arr, F_arr = [], [], [] + for rec in records: + F = rec["config"]["feature_dim"] + for meas in rec["measurements"]: + k = meas["params"]["k"] + t_med = median_of_trials(meas[timing_key]) + k_arr.append(k) + T_arr.append(t_med) + F_arr.append(F) + + k_arr = np.array(k_arr, dtype=float) + F_arr = np.array(F_arr, dtype=float) + T_arr = np.array(T_arr, dtype=float) + + bytes_arr = k_arr * F_arr * 4.0 # float32 + + # 1. The Piecewise Linear Model (No Logarithms) + def time_model(b, overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh): + # 1. Bucket the bytes into their respective physical regimes + # Bytes processed exclusively at L2 speeds + bytes_L2 = np.clip(b, 0, L2_thresh) + # Bytes processed exclusively at HBM speeds + bytes_HBM = np.maximum(0, b - HBM_thresh) + + # 2. 
Apply the specific bandwidth (slope) to each bucket + t_mem = (bytes_L2 * inv_bw_L2) + (bytes_HBM * inv_bw_HBM) + + # 3. Floor the total time by the kernel launch overhead + return np.maximum(overhead, t_mem) + + # 2. Strategic initial guesses + min_T, max_T = float(np.min(T_arr)), float(np.max(T_arr)) + max_b = float(np.max(bytes_arr)) + + # The asymptotic slope is the difference in max/min time over max bytes + inv_bw_HBM_guess = (max_T - min_T) / (max_b + 1e-9) + + p0 = [ + min_T, # overhead + inv_bw_HBM_guess * 0.3, # inv_bw_L2 + inv_bw_HBM_guess, # inv_bw_HBM + max_b * 0.00001, # L2_thresh + max_b * 0.001, # HBM_thresh + ] + + try: + bounds = ([0, 0, 0, 0, 0], [np.inf, np.inf, np.inf, max_b, max_b]) + + # 3. The Magic Fix: sigma=T_arr weights the fit by relative error + popt, _ = curve_fit( + time_model, + bytes_arr, + T_arr, + p0=p0, + bounds=bounds, + method="trf", + sigma=T_arr, + absolute_sigma=False, + ) + overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh = popt + + except Exception as e: + print(f"Fit failed: {e}") + overhead = inv_bw_L2 = inv_bw_HBM = L2_thresh = np.nan + + bw_HBM = 1.0 / inv_bw_HBM if inv_bw_HBM > 0 else float("nan") + bw_L2 = 1.0 / inv_bw_L2 if inv_bw_L2 > 0 else float("nan") + + # 3. 
Calculate linear R-squared in linear space + if not np.isnan(overhead): + T_pred = time_model( + bytes_arr, overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh + ) + ss_res = np.sum((T_arr - T_pred) ** 2) + ss_tot = np.sum((T_arr - np.mean(T_arr)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else float("nan") + else: + r_squared = float("nan") + + print( + f" Fitted gather: overhead={overhead*1e3:.3f} ms ", + f"BW_L2={bw_L2/1e9:.2f} GB/s BW_HBM={bw_HBM/1e9:.2f} GB/s " + f"L2_thresh={L2_thresh/1e6:.2f} MB HBM_thresh={HBM_thresh/1e6:.2f} MB " + f"R²={r_squared:.4f}", + ) + + return { + "bandwidth_bytes_per_sec": bw_HBM, + "L2_bandwidth_bytes_per_sec": bw_L2, + "L2_inflection_bytes": L2_thresh, + "HBM_inflection_bytes": HBM_thresh, + "launch_overhead_seconds": overhead, + "r_squared": r_squared, + } + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser(description="Fit cost-model primitive parameters") + p.add_argument("--pingpong-intra", nargs="+", default=[], metavar="FILE") + p.add_argument("--pingpong-inter", nargs="+", default=[], metavar="FILE") + p.add_argument("--compute-gcn", nargs="+", default=[], metavar="FILE") + p.add_argument("--compute-edge", nargs="+", default=[], metavar="FILE") + p.add_argument("--gather-contiguous", nargs="+", default=[], metavar="FILE") + p.add_argument("--gather-clustered", nargs="+", default=[], metavar="FILE") + p.add_argument("--gather-random", nargs="+", default=[], metavar="FILE") + p.add_argument("--output", type=str, default="data/fitted_primitives.json") + return p.parse_args() + + +def main(): + args = parse_args() + result = {} + + # Network + net = {} + if args.pingpong_intra: + recs = load_json_files(args.pingpong_intra) + net["intra"] = fit_network(recs) + print( + f"[network/intra] B={net['intra']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s 
" + f"t_L={net['intra']['latency_seconds']*1e6:.2f} µs " + f"R²={net['intra']['r_squared']:.4f}" + ) + if args.pingpong_inter: + recs = load_json_files(args.pingpong_inter) + net["inter"] = fit_network(recs) + print( + f"[network/inter] B={net['inter']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"t_L={net['inter']['latency_seconds']*1e6:.2f} µs " + f"R²={net['inter']['r_squared']:.4f}" + ) + result["network"] = net + + # Compute + comp = {} + if args.compute_gcn: + recs = load_json_files(args.compute_gcn) + comp["gcn"] = { + "forward": fit_compute(recs, "forward_trials_seconds"), + "backward": fit_compute(recs, "backward_trials_seconds"), + } + print( + f"[compute/gcn] coeff_V={comp['gcn']['forward']['coeff_V']:.3e} " + f"coeff_E={comp['gcn']['forward']['coeff_E']:.3e} " + f"R²={comp['gcn']['forward']['r_squared']:.4f}" + ) + if args.compute_edge: + recs = load_json_files(args.compute_edge) + comp["edge"] = { + "forward": fit_compute(recs, "forward_trials_seconds"), + "backward": fit_compute(recs, "backward_trials_seconds"), + } + print( + f"[compute/edge] coeff_V={comp['edge']['forward']['coeff_V']:.3e} " + f"coeff_E={comp['edge']['forward']['coeff_E']:.3e} " + f"R²={comp['edge']['forward']['r_squared']:.4f}" + ) + result["compute"] = comp + + # Gather + gath = {} + for dist_name, files_attr in [ + ("contiguous", "gather_contiguous"), + ("clustered", "gather_clustered"), + ("random", "gather_random"), + ]: + files = getattr(args, files_attr.replace("-", "_")) + if files: + recs = load_json_files(files) + gath[dist_name] = { + "gather": fit_gather(recs, "gather_trials_seconds"), + "scatter_add": fit_gather(recs, "scatter_add_trials_seconds"), + } + print( + f"[gather/{dist_name}] " + f"B_gather={gath[dist_name]['gather']['bandwidth_bytes_per_sec']/1e9:.2f} GB/s " + f"R²={gath[dist_name]['gather']['r_squared']:.4f}" + ) + result["gather"] = gath + + # Write + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, 
"w") as f: + json.dump(result, f, indent=2) + print(f"[fit_primitives] Written to {out_path}") + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/__init__.py b/experiments/cost_model_benchmarks/benchmarks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_compute.py b/experiments/cost_model_benchmarks/benchmarks/bench_compute.py new file mode 100644 index 0000000..89f5249 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_compute.py @@ -0,0 +1,232 @@ +"""Benchmark 1.3 — GNN Layer Compute Primitive. + +Single-GPU benchmark. Fits f_comp(|Ṽ|, |Ẽ|) for two message-function +variants: + +* ``gcn`` — GCN-like: φ(h_b) = W h_b (source-only linear transform) +* ``edge`` — Edge-conditioned: φ(h_b, h_a, e_ba) = MLP([h_b, h_a, e_ba]) + with a 2-layer MLP (hidden dim = feature_dim) + +Two sweep modes (controlled by ``--sweep``): + +* ``vertices`` — vary |V| with |E| fixed at ``--fixed-value`` +* ``edges`` — vary |E| with |V| fixed at ``--fixed-value`` + +Usage:: + + python -m benchmarks.bench_compute \\ + --model edge --sweep vertices \\ + --min 1000 --max 100000 --steps 15 \\ + --fixed-value 500000 --feature-dim 128 \\ + --warmup 10 --trials 50 \\ + --output data/compute_edge_vswp.json --seed 42 +""" + +import argparse + +import numpy as np +import torch +import torch.nn as nn + +from benchmarks.common import ( + collect_metadata, + cuda_timed, + seed_everything, + write_result, +) + + +# --------------------------------------------------------------------------- +# Synthetic graph generation +# --------------------------------------------------------------------------- + + +def erdos_renyi_edges( + num_vertices: int, num_edges: int, device: torch.device +) -> torch.Tensor: + """Return an edge index tensor of shape [2, num_edges] (random, with replacement).""" + src = torch.randint(0, num_vertices, (num_edges,), device=device) + dst = 
torch.randint(0, num_vertices, (num_edges,), device=device) + return torch.stack([src, dst], dim=0) + + +# --------------------------------------------------------------------------- +# GNN layers +# --------------------------------------------------------------------------- + + +class GCNLayer(nn.Module): + """GCN-like: aggregate neighbour source features with a linear transform.""" + + def __init__(self, feature_dim: int): + super().__init__() + self.linear = nn.Linear(feature_dim, feature_dim, bias=False) + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: + # x: [V, F], edge_index: [2, E] + src, dst = edge_index[0], edge_index[1] + # Message: transform source features + msg = self.linear(x[src]) # [E, F] + # Aggregate: scatter-add to destination + out = torch.zeros_like(x) + out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) + return out + + +class EdgeConditionedLayer(nn.Module): + """Edge-conditioned: φ(h_b, h_a, e_ba) = MLP([h_b, h_a, e_ba]).""" + + def __init__(self, feature_dim: int): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(3 * feature_dim, feature_dim), + nn.ReLU(), + nn.Linear(feature_dim, feature_dim), + ) + + def forward( + self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor + ) -> torch.Tensor: + # x: [V, F], edge_index: [2, E], edge_attr: [E, F] + src, dst = edge_index[0], edge_index[1] + msg_input = torch.cat([x[src], x[dst], edge_attr], dim=-1) # [E, 3F] + msg = self.mlp(msg_input) # [E, F] + out = torch.zeros_like(x) + out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) + return out + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser(description="GNN compute primitive benchmark") + p.add_argument("--model", choices=["gcn", "edge"], required=True) + p.add_argument("--sweep", choices=["vertices", 
"edges"], required=True) + p.add_argument("--min", type=int, default=1_000, dest="sweep_min") + p.add_argument("--max", type=int, default=1_000_000, dest="sweep_max") + p.add_argument("--steps", type=int, default=15) + p.add_argument( + "--fixed-value", + type=int, + default=500_000, + help="Fixed |E| when sweeping vertices, or fixed |V| when sweeping edges", + ) + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + seed_everything(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + F = args.feature_dim + + # Build model + if args.model == "gcn": + model = GCNLayer(F).to(device) + else: + model = EdgeConditionedLayer(F).to(device) + + # Sweep points + sweep_vals = np.unique( + np.round( + np.logspace( + np.log10(args.sweep_min), + np.log10(args.sweep_max), + num=args.steps, + ) + ).astype(int) + ).tolist() + + measurements = [] + for val in sweep_vals: + if args.sweep == "vertices": + num_v, num_e = val, args.fixed_value + else: + num_v, num_e = args.fixed_value, val + + # Synthetic data + x = torch.randn(num_v, F, device=device, requires_grad=True) + edge_index = erdos_renyi_edges(num_v, num_e, device) + edge_attr = ( + torch.randn(num_e, F, device=device) if args.model == "edge" else None + ) + + # Forward timing + def fwd(): + if args.model == "gcn": + model(x, edge_index) + else: + model(x, edge_index, edge_attr) + + fwd_times = cuda_timed(fwd, warmup=args.warmup, trials=args.trials) + + # Backward timing (run fwd first to get a graph) + if args.model == "gcn": + out = model(x, edge_index) + else: + out = model(x, edge_index, edge_attr) + loss_ref = out.sum() + + def bwd(): + if x.grad is not None: + x.grad.zero_() + if args.model == "gcn": + out_inner = 
model(x, edge_index) + else: + out_inner = model(x, edge_index, edge_attr) + out_inner.sum().backward() + + bwd_times = cuda_timed(bwd, warmup=args.warmup, trials=args.trials) + + measurements.append( + { + "params": { + "num_vertices": num_v, + "num_edges": num_e, + "sweep_var": args.sweep, + "sweep_value": val, + "model": args.model, + "feature_dim": F, + }, + "forward_trials_seconds": fwd_times, + "backward_trials_seconds": bwd_times, + } + ) + med_fwd = sorted(fwd_times)[len(fwd_times) // 2] + med_bwd = sorted(bwd_times)[len(bwd_times) // 2] + print( + f"[compute/{args.model}] |V|={num_v:>8} |E|={num_e:>9} " + f"fwd {1e3*med_fwd:.2f} ms bwd {1e3*med_bwd:.2f} ms" + ) + + payload = { + "benchmark": "compute", + "metadata": collect_metadata(), + "config": { + "model": args.model, + "sweep": args.sweep, + "sweep_min": args.sweep_min, + "sweep_max": args.sweep_max, + "steps": args.steps, + "fixed_value": args.fixed_value, + "feature_dim": F, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + }, + "measurements": measurements, + } + write_result(args.output, payload) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py b/experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py new file mode 100644 index 0000000..8d28cfd --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_concurrency.py @@ -0,0 +1,227 @@ +"""Benchmark 1.2 — Intra/Inter Concurrency Check. + +Verifies that T_both / max(T_intra, T_inter) ≈ 1, i.e. that NVLink and +InfiniBand transfers overlap when issued simultaneously. + +Requires exactly 4 ranks across 2 nodes: + Node A: rank 0 (A0), rank 1 (A1) + Node B: rank 2 (B0), rank 3 (B1) + +Three conditions at a fixed message size: + 1. intra-only — A0↔A1 and B0↔B1 (no cross-node traffic) + 2. inter-only — A0↔B0 and A1↔B1 (no intra-node traffic) + 3. 
concurrent — all four exchanges in the same window (separate streams) + +Each rank logs its own wall time per trial. Rank 0 collects and writes JSON. + +Usage:: + + srun -N 2 --ntasks-per-node 2 python -m benchmarks.bench_concurrency \\ + --message-bytes 16777216 --warmup 20 --trials 100 \\ + --output data/concurrency.json --seed 42 +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +from benchmarks.common import ( + collect_metadata, + seed_everything, + setup_distributed, + write_result, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _exchange(tensor_send: torch.Tensor, tensor_recv: torch.Tensor, + peer: int, stream: torch.cuda.Stream) -> None: + """Non-blocking send+recv on *stream* with *peer*.""" + with torch.cuda.stream(stream): + send_op = dist.P2POp(dist.isend, tensor_send, peer) + recv_op = dist.P2POp(dist.irecv, tensor_recv, peer) + reqs = dist.batch_isend_irecv([send_op, recv_op]) + for req in reqs: + req.wait() + + +def timed_window(rank: int, + send_buf: torch.Tensor, + recv_buf: torch.Tensor, + peer: int, + stream: torch.cuda.Stream, + warmup: int, + trials: int) -> list: + """Time a single exchange window (one send+recv pair).""" + for _ in range(warmup): + _exchange(send_buf, recv_buf, peer, stream) + torch.cuda.synchronize() + dist.barrier() + + times = [] + for _ in range(trials): + dist.barrier() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + _exchange(send_buf, recv_buf, peer, stream) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1_000.0) + return times + + +def timed_concurrent(rank: int, + intra_send: torch.Tensor, intra_recv: torch.Tensor, intra_peer: int, + inter_send: torch.Tensor, inter_recv: torch.Tensor, inter_peer: int, + intra_stream: torch.cuda.Stream, + 
inter_stream: torch.cuda.Stream, + warmup: int, + trials: int) -> list: + """Time intra and inter exchanges issued concurrently on separate streams.""" + for _ in range(warmup): + _exchange(intra_send, intra_recv, intra_peer, intra_stream) + _exchange(inter_send, inter_recv, inter_peer, inter_stream) + torch.cuda.synchronize() + dist.barrier() + + times = [] + for _ in range(trials): + dist.barrier() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + _exchange(intra_send, intra_recv, intra_peer, intra_stream) + _exchange(inter_send, inter_recv, inter_peer, inter_stream) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end) / 1_000.0) + return times + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="Intra/inter concurrency benchmark") + p.add_argument("--message-bytes", type=int, default=16_777_216) # 16 MiB + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--trials", type=int, default=100) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + rank, world_size, local_rank = setup_distributed() + + if world_size != 4: + raise ValueError( + f"bench_concurrency requires exactly 4 ranks (got {world_size}).\n" + "Layout: rank 0,1 on node A; rank 2,3 on node B." 
+ ) + + seed_everything(args.seed) + device = torch.device(f"cuda:{local_rank}") + + num_elems = max(1, args.message_bytes // 4) + send_buf = torch.randn(num_elems, dtype=torch.float32, device=device) + recv_buf = torch.zeros(num_elems, dtype=torch.float32, device=device) + + # Intra-node peers: 0↔1, 2↔3 + # Inter-node peers: 0↔2, 1↔3 + intra_peer = {0: 1, 1: 0, 2: 3, 3: 2}[rank] + inter_peer = {0: 2, 1: 3, 2: 0, 3: 1}[rank] + + intra_stream = torch.cuda.Stream(device=device) + inter_stream = torch.cuda.Stream(device=device) + + intra_send = send_buf.clone() + intra_recv = torch.zeros_like(recv_buf) + inter_send = send_buf.clone() + inter_recv = torch.zeros_like(recv_buf) + + # --- Condition 1: intra-only --- + times_intra = timed_window( + rank, intra_send, intra_recv, intra_peer, intra_stream, + args.warmup, args.trials + ) + dist.barrier() + + # --- Condition 2: inter-only --- + times_inter = timed_window( + rank, inter_send, inter_recv, inter_peer, inter_stream, + args.warmup, args.trials + ) + dist.barrier() + + # --- Condition 3: concurrent --- + times_concurrent = timed_concurrent( + rank, + intra_send, intra_recv, intra_peer, + inter_send, inter_recv, inter_peer, + intra_stream, inter_stream, + args.warmup, args.trials, + ) + dist.barrier() + + # Gather per-rank times to rank 0 + def gather_times(times_local): + obj = [None] * world_size + dist.all_gather_object(obj, times_local) + return obj + + intra_all = gather_times(times_intra) + inter_all = gather_times(times_inter) + conc_all = gather_times(times_concurrent) + + if rank == 0: + measurements = [ + { + "params": {"condition": "intra_only", "message_bytes": num_elems * 4}, + "per_rank_trials_seconds": intra_all, + }, + { + "params": {"condition": "inter_only", "message_bytes": num_elems * 4}, + "per_rank_trials_seconds": inter_all, + }, + { + "params": {"condition": "concurrent", "message_bytes": num_elems * 4}, + "per_rank_trials_seconds": conc_all, + }, + ] + payload = { + "benchmark": "concurrency", 
+ "metadata": collect_metadata(), + "config": { + "message_bytes": args.message_bytes, + "warmup": args.warmup, + "trials": args.trials, + "world_size": world_size, + "seed": args.seed, + "rank_layout": "rank 0,1 on node A; rank 2,3 on node B", + }, + "measurements": measurements, + } + write_result(args.output, payload) + print( + f"[concurrency] intra median = " + f"{1e3*sorted(intra_all[0])[len(intra_all[0])//2]:.2f} ms | " + f"inter median = " + f"{1e3*sorted(inter_all[0])[len(inter_all[0])//2]:.2f} ms | " + f"concurrent median = " + f"{1e3*sorted(conc_all[0])[len(conc_all[0])//2]:.2f} ms" + ) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py b/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py new file mode 100644 index 0000000..c5c2922 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_crossover.py @@ -0,0 +1,557 @@ +"""Benchmark 2.2 — Multi-GPU Crossover Point. + +Sweeps graph size to identify the point at which distributing computation +across K GPUs overcomes the overhead of halo-exchange communication. + +For each graph size N in the sweep: + + * **Single-GPU baseline**: rank 0 runs forward+backward on the *complete* + graph on one GPU. Measures raw compute cost T_comp(N). + * **Multi-GPU distributed**: all K ranks execute partitioned + forward+backward with halo exchange. Measures + T_comp(N/K) + T_comm. + +The *crossover* N* is the smallest graph size where +T_single(N) > T_multi(K, N), i.e. where distributing first becomes +beneficial. + +When ``--no-dist`` is set the script runs the single-GPU sweep only +(useful for characterising baseline compute without a multi-GPU +allocation). 
+ +Synthetic graphs: + * ``erdos_renyi`` — Erdős-Rényi with ``--avg-degree`` expected degree + * ``sbm`` — Stochastic Block Model; ``--sbm-inter-density`` + controls the fraction of inter-block edges + +Partitioners: + * ``random`` — assign each vertex to a uniformly random rank + * ``balanced`` — contiguous vertex blocks of equal size + * ``metis`` — balanced k-way via pymetis (skipped if not installed) + +Usage — distributed (torchrun):: + + torchrun --nnodes 1 --nproc_per_node 4 \\ + -m benchmarks.benchmark_crossover \\ + --graph erdos_renyi \\ + --graph-sizes 10000,100000,1000000,10000000 \\ + --avg-degree 20 --feature-dim 128 --model gcn \\ + --partitioner balanced --warmup 10 --trials 50 \\ + --output data/crossover_K8_F128_er_bal.json --seed 42 + +Usage — single-GPU baseline only:: + + python -m benchmarks.benchmark_crossover --no-dist \\ + --graph erdos_renyi \\ + --graph-sizes 10000,100000,1000000,10000000 \\ + --avg-degree 20 --feature-dim 128 --model gcn \\ + --warmup 10 --trials 50 \\ + --output data/crossover_single_F128.json --seed 42 +""" + +import argparse +import os + +import numpy as np +import torch +import torch.distributed as dist +import gc + +from benchmarks.common import ( + collect_metadata, + cuda_timed, + seed_everything, + setup_distributed, + write_result, +) +from benchmarks.graph_data_common import ( + gen_erdos_renyi, + gen_sbm, + partition_balanced, + partition_metis, + partition_random, +) +from benchmarks.nn_layer_common import GCNLayer, EdgeConditionedLayer + +from DGraph.distributed import ( + HaloExchange, + CommunicationPattern, + build_communication_pattern, +) +from DGraph import Communicator + + +# --------------------------------------------------------------------------- +# Argument parsing +# --------------------------------------------------------------------------- + + +def parse_args(): + p = argparse.ArgumentParser(description="Multi-GPU crossover point benchmark") + p.add_argument("--graph", choices=["erdos_renyi", 
"sbm"], default="erdos_renyi") + p.add_argument( + "--graph-sizes", + type=str, + default="100,200,400,800,1000,2000,4000,8000,10000,16000,32000,100000,200000,400000", + help="Comma-separated list of num_vertices values to sweep", + ) + p.add_argument("--avg-degree", type=float, default=20.0) + p.add_argument( + "--sbm-inter-density", + type=float, + default=0.1, + help="Fraction of inter-block edges for SBM graphs", + ) + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--model", choices=["gcn", "edge"], default="gcn") + p.add_argument( + "--partitioner", choices=["random", "balanced", "metis"], default="balanced" + ) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--no-dist", + action="store_true", + help="Run single-GPU sweep only (no distributed setup required)", + ) + return p.parse_args() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _gen_graph(graph_type, num_vertices, avg_degree, sbm_inter_density, seed): + """Generate a synthetic graph reproducibly for a given graph size.""" + rng = np.random.default_rng(seed) + if graph_type == "erdos_renyi": + return gen_erdos_renyi(num_vertices, avg_degree, rng) + else: + return gen_sbm(num_vertices, avg_degree, sbm_inter_density, rng) + + +def _intra_inter_halo(comm_pattern: CommunicationPattern, ranks_per_node: int) -> tuple: + """Return (intra_halo_vertices, inter_halo_vertices) from recv_offset.""" + rank = comm_pattern.rank + my_node = rank // ranks_per_node + recv_counts = ( + comm_pattern.recv_offset[1:] - comm_pattern.recv_offset[:-1] + ).tolist() + intra = 0 + inter = 0 + for r, count in enumerate(recv_counts): + if r == rank: + continue + if (r // ranks_per_node) == my_node: + 
intra += int(count) + else: + inter += int(count) + return intra, inter + + +def _build_single_gpu_tensors(num_vertices, edges_np, F, model, device): + """Allocate full-graph tensors on *device*. May raise cuda.OutOfMemoryError.""" + # GCNLayer / EdgeConditionedLayer expect edge_index as [2, E] + edge_t = torch.from_numpy(edges_np.T.copy()).long().to(device) + x = torch.randn(num_vertices, F, device=device, requires_grad=True) + if model == "gcn": + layer = GCNLayer(F).to(device) + edge_attr = None + else: + layer = EdgeConditionedLayer(F).to(device) + edge_attr = torch.randn(edges_np.shape[0], F, device=device) + layer.train() + return x, edge_t, layer, edge_attr + + +def _single_gpu_fn(x, edge_t, layer, edge_attr): + """Return a zero-argument closure for single-GPU forward+backward.""" + if edge_attr is None: + + def fn(): + out = layer(x, edge_t) + out.sum().backward() + if x.grad is not None: + x.grad.zero_() + + else: + + def fn(): + out = layer(x, edge_t, edge_attr) + out.sum().backward() + if x.grad is not None: + x.grad.zero_() + + return fn + + +def _multi_gpu_fn(x_local, comm_pattern, layer, halo_exchange, edge_attr, model): + """Return a zero-argument closure for distributed forward+backward. + + ``comm_pattern.local_edge_list`` has shape ``[E, 2]``; we transpose it + once to ``[2, E]`` as required by GCNLayer / EdgeConditionedLayer. 
+ """ + edge_index = comm_pattern.local_edge_list.T.contiguous() # [2, E_local] + + if model == "gcn": + + def fn(): + recv_buf = halo_exchange(x_local, comm_pattern) + x_aug = torch.cat([x_local, recv_buf], dim=0) + + out = layer(x_aug, edge_index) + out.sum().backward() + if x_local.grad is not None: + x_local.grad.zero_() + + else: + + def fn(): + recv_buf = halo_exchange(x_local, comm_pattern) + x_aug = torch.cat([x_local, recv_buf], dim=0) + out = layer(x_aug, edge_index, edge_attr) + out.sum().backward() + if x_local.grad is not None: + x_local.grad.zero_() + + return fn + + +# --------------------------------------------------------------------------- +# Single-GPU-only path +# --------------------------------------------------------------------------- + + +def run_no_dist(args, graph_sizes, F): + """Sweep graph sizes on a single GPU and return the measurement list.""" + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + seed_everything(args.seed) + measurements = [] + + for num_vertices in graph_sizes: + # Unique seed per graph size so each topology is reproducible independently + graph_seed = args.seed + num_vertices + edges_np = _gen_graph( + args.graph, + num_vertices, + args.avg_degree, + args.sbm_inter_density, + graph_seed, + ) + num_edges = int(edges_np.shape[0]) + + times_single = None + oom = False + try: + x, edge_t, layer, edge_attr = _build_single_gpu_tensors( + num_vertices, edges_np, F, args.model, device + ) + fn = _single_gpu_fn(x, edge_t, layer, edge_attr) + times_single = cuda_timed(fn, warmup=args.warmup, trials=args.trials) + # Free before next iteration + del x, edge_t, layer + if edge_attr is not None: + del edge_attr + torch.cuda.empty_cache() + except torch.cuda.OutOfMemoryError: + oom = True + print( + f"[crossover] OOM: N={num_vertices:,} exceeds single-GPU memory; " + "skipping and continuing" + ) + torch.cuda.empty_cache() + + med_str = ( + f"{1e3 * sorted(times_single)[len(times_single) // 2]:.2f} ms" + if 
times_single + else "OOM" + ) + print(f"[crossover] no-dist N={num_vertices:>12,} single={med_str}") + + measurements.append( + { + "params": { + "num_vertices": num_vertices, + "num_edges": num_edges, + "avg_degree": args.avg_degree, + "feature_dim": F, + "model": args.model, + "graph": args.graph, + }, + "single_gpu_trials_seconds": times_single, + "single_gpu_oom": oom, + "multi_gpu_trials_seconds_rank0": None, + "multi_gpu_trials_seconds_max": None, + "per_rank_stats": None, + "world_size": 1, + } + ) + + return measurements + + +# --------------------------------------------------------------------------- +# Distributed path +# --------------------------------------------------------------------------- + + +def run_distributed(args, graph_sizes, F, rank, world_size, local_rank): + """Run one crossover measurement per graph size using all K ranks.""" + device = torch.device(f"cuda:{local_rank}") + + comm = Communicator(backend="nccl") + + ranks_per_node = int( + os.environ.get("LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4")) + ) + + measurements = [] + + for num_vertices in graph_sizes: + halo_exchange = HaloExchange(comm=comm) + # All ranks generate *identical* graph topology (same seed per size) + graph_seed = args.seed + num_vertices + edges_np = _gen_graph( + args.graph, + num_vertices, + args.avg_degree, + args.sbm_inter_density, + graph_seed, + ) + + # All ranks compute *identical* partition assignment + rng_part = np.random.default_rng(args.seed + 1 + num_vertices) + if args.partitioner == "random": + assignment_np = partition_random(num_vertices, world_size, rng_part) + elif args.partitioner == "balanced": + assignment_np = partition_balanced(num_vertices, world_size) + else: + assignment_np = partition_metis(num_vertices, world_size, edges_np) + + # Move to GPU for DGraph's build_communication_pattern (collective) + edges_t = torch.from_numpy(edges_np).long().to(device) # [E, 2] + partitioning_t = 
torch.from_numpy(assignment_np).long().to(device) # [V] + + # Collective: all ranks call this in sync (internally calls dist.all_gather) + comm_pattern = build_communication_pattern( + edges_t, partitioning_t, rank, world_size + ) + + n_local = comm_pattern.num_local_vertices + n_halo = comm_pattern.num_halo_vertices + n_local_edges = comm_pattern.local_edge_list.shape[0] + + x_local = torch.randn(n_local, F, device=device, requires_grad=True) + edge_attr_dist = ( + torch.randn(n_local_edges, F, device=device) + if args.model == "edge" + else None + ) + layer_dist = ( + GCNLayer(F).to(device) + if args.model == "gcn" + else EdgeConditionedLayer(F).to(device) + ) + layer_dist.train() + + fn_multi = _multi_gpu_fn( + x_local, comm_pattern, layer_dist, halo_exchange, edge_attr_dist, args.model + ) + + # ---- Time multi-GPU (all ranks participate) ---- + dist.barrier() + times_multi_local = cuda_timed(fn_multi, warmup=args.warmup, trials=args.trials) + dist.barrier() + + # Gather timings and partition stats from all ranks to rank 0 + all_times_multi = [None] * world_size + dist.all_gather_object(all_times_multi, times_multi_local) + + intra_halo, inter_halo = _intra_inter_halo(comm_pattern, ranks_per_node) + stats_local = { + "rank": rank, + "n_local": n_local, + "n_halo": n_halo, + "n_local_edges": n_local_edges, + "intra_halo_size": intra_halo, + "inter_halo_size": inter_halo, + "c_intra_bytes": intra_halo * F * 4, + "c_inter_bytes": inter_halo * F * 4, + "send_total": int(comm_pattern.send_offset[-1].item()), + "recv_total": int(comm_pattern.recv_offset[-1].item()), + "trials_seconds": times_multi_local, + } + all_stats = [None] * world_size + dist.all_gather_object(all_stats, stats_local) + + # Free GPU memory before the single-GPU run + del ( + x_local, + edges_t, + partitioning_t, + layer_dist, + halo_exchange, + comm_pattern, + fn_multi, + ) + if edge_attr_dist is not None: + del edge_attr_dist + gc.collect() + torch.cuda.empty_cache() + + # ---- Rank 0: time 
single-GPU baseline (non-collective) ---- + times_single = None + single_oom = False + if rank == 0: + gc.collect() + torch.cuda.empty_cache() + edges_s = _gen_graph( + args.graph, + num_vertices, + args.avg_degree, + args.sbm_inter_density, + graph_seed, + ) + try: + x_s, edge_t_s, layer_s, edge_attr_s = _build_single_gpu_tensors( + num_vertices, edges_s, F, args.model, device + ) + fn_single = _single_gpu_fn(x_s, edge_t_s, layer_s, edge_attr_s) + times_single = cuda_timed( + fn_single, warmup=args.warmup, trials=args.trials + ) + del x_s, edge_t_s, layer_s, fn_single + if edge_attr_s is not None: + del edge_attr_s + torch.cuda.empty_cache() + except torch.cuda.OutOfMemoryError: + single_oom = True + print( + f"[crossover] rank 0 OOM: N={num_vertices:,} exceeds single-GPU memory" + ) + torch.cuda.empty_cache() + + # Sync all ranks after rank 0 finishes its single-GPU benchmark + dist.barrier() + + if rank == 0: + n_trials = len(times_multi_local) + # Wall time = max latency across all ranks per trial + times_multi_max = [ + max(all_times_multi[r][i] for r in range(world_size)) + for i in range(n_trials) + ] + + med_multi = sorted(times_multi_max)[n_trials // 2] + if times_single: + med_single = sorted(times_single)[len(times_single) // 2] + speedup = med_single / med_multi if med_multi > 0 else float("nan") + else: + med_single = float("nan") + speedup = float("nan") + + print( + f"[crossover] K={world_size:>3} N={num_vertices:>12,} " + f"single={1e3*med_single:.2f}ms " + f"multi={1e3*med_multi:.2f}ms " + f"speedup={speedup:.2f}x" + ) + + measurements.append( + { + "params": { + "num_vertices": num_vertices, + "num_global_edges": int(edges_np.shape[0]), + "avg_degree": args.avg_degree, + "feature_dim": F, + "model": args.model, + "graph": args.graph, + "partitioner": args.partitioner, + "world_size": world_size, + "ranks_per_node": ranks_per_node, + }, + "single_gpu_trials_seconds": times_single, + "single_gpu_oom": single_oom, + "multi_gpu_trials_seconds_rank0": 
times_multi_local, + "multi_gpu_trials_seconds_max": times_multi_max, + "per_rank_stats": all_stats, + } + ) + + return measurements, ranks_per_node + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(): + args = parse_args() + graph_sizes = [int(s) for s in args.graph_sizes.split(",")] + F = args.feature_dim + + if args.no_dist: + measurements = run_no_dist(args, graph_sizes, F) + payload = { + "benchmark": "crossover", + "metadata": collect_metadata(), + "config": { + "graph": args.graph, + "graph_sizes": graph_sizes, + "avg_degree": args.avg_degree, + "sbm_inter_density": args.sbm_inter_density, + "feature_dim": F, + "model": args.model, + "partitioner": "none", + "world_size": 1, + "ranks_per_node": 1, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + "mode": "single_gpu_only", + }, + "measurements": measurements, + } + write_result(args.output, payload) + return + + # ---- Distributed run ---- + rank, world_size, local_rank = setup_distributed() + seed_everything(args.seed + rank) + + measurements, ranks_per_node = run_distributed( + args, graph_sizes, F, rank, world_size, local_rank + ) + + if rank == 0: + payload = { + "benchmark": "crossover", + "metadata": collect_metadata(), + "config": { + "graph": args.graph, + "graph_sizes": graph_sizes, + "avg_degree": args.avg_degree, + "sbm_inter_density": args.sbm_inter_density, + "feature_dim": F, + "model": args.model, + "partitioner": args.partitioner, + "world_size": world_size, + "ranks_per_node": ranks_per_node, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + "mode": "distributed", + }, + "measurements": measurements, + } + write_result(args.output, payload) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py 
b/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py new file mode 100644 index 0000000..4cfe162 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_end_to_end.py @@ -0,0 +1,300 @@ +"""Benchmark 2.1 — End-to-End Halo Exchange. + +Measures full GNN layer wall time (forward + backward) across a sweep of +configurations on the full multi-node setup. Intended to be run as a SLURM +array job with one invocation per (K, F, graph) combination. + +This module contains a self-contained minimal halo-exchange implementation +(no dependency on the DGraph production library) so the benchmark remains +isolated and portable. + +Synthetic graphs: + * ``erdos_renyi`` — Erdős-Rényi with ``--avg-degree`` expected degree + * ``sbm`` — Stochastic Block Model; ``--sbm-inter-density`` controls + the fraction of inter-block edges (topology ablation) + +Partitioners: + * ``random`` — assign each vertex to a uniformly random rank + * ``balanced`` — contiguous vertex blocks of equal size + * ``metis`` — balanced k-way via pymetis (skipped if not installed) + +The benchmark logs, for every run: + * world_size K, feature dim F, graph type, partitioner + * per-rank partition statistics: + intra_halo_size — halo vertices on the same node + inter_halo_size — halo vertices on different nodes + c_intra, c_inter — communication volumes (bytes) + * per-trial layer times from rank 0 (and per-rank times for completeness) + +Usage:: + + torchrun --nnodes 2 --nproc_per_node 4 \\ + -m benchmarks.bench_end_to_end \\ + --graph erdos_renyi --num-vertices 100000 --avg-degree 20 \\ + --feature-dim 128 --model gcn --partitioner balanced \\ + --warmup 10 --trials 50 \\ + --output data/e2e_K8_F128_er_bal.json --seed 42 +""" + +import argparse +import os + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn + +from benchmarks.common import ( + collect_metadata, + cuda_timed, + seed_everything, + setup_distributed, + write_result, +) +from 
benchmarks.graph_data_common import ( + gen_erdos_renyi, + gen_sbm, + partition_balanced, + partition_metis, + partition_random, +) +from benchmarks.nn_layer_common import GCNLayer, EdgeConditionedLayer + + +class MinimalHaloExchange(torch.autograd.Function): + """Forward: gather boundary features → all_to_all → populate recv buffer. + Backward: reverse the transfer to accumulate gradients. + """ + + @staticmethod + def forward(ctx, x_local, send_idx_flat, send_counts, recv_counts, world_size): + # Gather send buffer + send_buf = x_local[send_idx_flat] # [total_send, F] + + # Split by destination rank + send_list = list(send_buf.split(send_counts, dim=0)) + recv_list = [ + torch.zeros( + rc, x_local.shape[1], dtype=x_local.dtype, device=x_local.device + ) + for rc in recv_counts + ] + + dist.all_to_all(recv_list, send_list) + + recv_buf = ( + torch.cat(recv_list, dim=0) + if sum(recv_counts) > 0 + else torch.zeros(0, x_local.shape[1], device=x_local.device) + ) + + ctx.save_for_backward(send_idx_flat) + ctx.send_counts = send_counts + ctx.recv_counts = recv_counts + ctx.world_size = world_size + ctx.n_local = x_local.shape[0] + ctx.feature_dim = x_local.shape[1] + ctx.device = x_local.device + + return recv_buf + + @staticmethod + def backward(ctx, grad_recv): + (send_idx_flat,) = ctx.saved_tensors + send_counts = ctx.send_counts + recv_counts = ctx.recv_counts + world_size = ctx.world_size + n_local = ctx.n_local + F = ctx.feature_dim + device = ctx.device + + # Reverse: recv_counts become send_counts and vice versa + grad_recv_list = ( + list(grad_recv.split(recv_counts, dim=0)) + if grad_recv.shape[0] > 0 + else [torch.zeros(0, F, device=device)] * world_size + ) + grad_send_list = [torch.zeros(sc, F, device=device) for sc in send_counts] + + dist.all_to_all(grad_send_list, grad_recv_list) + + grad_send = torch.cat(grad_send_list, dim=0) + + # Scatter-add back to local vertices + grad_x_local = torch.zeros(n_local, F, device=device, dtype=grad_recv.dtype) + 
grad_x_local.scatter_add_( + 0, + send_idx_flat.unsqueeze(1).expand_as(grad_send), + grad_send, + ) + return grad_x_local, None, None, None, None + + +# =========================================================================== +# Main +# =========================================================================== + + +def parse_args(): + p = argparse.ArgumentParser(description="End-to-end halo exchange benchmark") + p.add_argument("--graph", choices=["erdos_renyi", "sbm"], default="erdos_renyi") + p.add_argument("--num-vertices", type=int, default=100_000) + p.add_argument("--avg-degree", type=float, default=20.0) + p.add_argument( + "--sbm-inter-density", + type=float, + default=0.1, + help="Fraction of inter-block edges for SBM graphs", + ) + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--model", choices=["gcn", "edge"], default="gcn") + p.add_argument( + "--partitioner", choices=["random", "balanced", "metis"], default="balanced" + ) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + rank, world_size, local_rank = setup_distributed() + seed_everything(args.seed + rank) # per-rank seed for graph generation + rng = np.random.default_rng(args.seed) # shared seed for graph topology + device = torch.device(f"cuda:{local_rank}") + F = args.feature_dim + + # --- Generate graph on all ranks (same seed → identical graph) --- + if args.graph == "erdos_renyi": + edges = gen_erdos_renyi(args.num_vertices, args.avg_degree, rng) + else: + edges = gen_sbm(args.num_vertices, args.avg_degree, args.sbm_inter_density, rng) + + # --- Partition --- + rng_part = np.random.default_rng(args.seed + 1) + if args.partitioner == "random": + assignment = partition_random(args.num_vertices, world_size, rng_part) + elif args.partitioner == 
"balanced": + assignment = partition_balanced(args.num_vertices, world_size) + else: + assignment = partition_metis(args.num_vertices, world_size, edges) + + # --- Build local comm pattern --- + pattern = build_local_comm_pattern(edges, assignment, rank, world_size) + n_local = pattern["n_local"] + n_halo = pattern["n_halo"] + edge_index = pattern["edge_index"].to(device) + + send_counts = pattern["send_counts"] + recv_counts = pattern["recv_counts"] + send_idx_flat = ( + torch.cat( + [torch.tensor(s, dtype=torch.long) for s in pattern["send_idx_by_rank"]] + ).to(device) + if sum(send_counts) > 0 + else torch.zeros(0, dtype=torch.long, device=device) + ) + + # --- Model --- + if args.model == "gcn": + layer = GCNLayer(F).to(device) + else: + layer = EdgeConditionedLayer(F).to(device) + layer.train() + + # --- Synthetic local node features --- + x_local = torch.randn(n_local, F, device=device, requires_grad=True) + edge_attr = ( + torch.randn(edge_index.shape[1], F, device=device) + if args.model == "edge" + else None + ) + + # --- Timed forward + backward --- + def one_layer(): + # Forward halo exchange + recv_buf = MinimalHaloExchange.apply( + x_local, send_idx_flat, send_counts, recv_counts, world_size + ) + # Augment: local + halo + x_aug = torch.cat([x_local, recv_buf], dim=0) + # Message passing + if args.model == "gcn": + out = layer(x_aug, edge_index) + else: + out = layer(x_aug, edge_index, edge_attr) + # Backward + loss = out.sum() + loss.backward() + if x_local.grad is not None: + x_local.grad.zero_() + + # Barrier before timing + dist.barrier() + times_local = cuda_timed(one_layer, warmup=args.warmup, trials=args.trials) + dist.barrier() + + # Gather per-rank times and stats to rank 0 + stats_local = { + "rank": rank, + "n_local": n_local, + "n_halo": n_halo, + "intra_halo_size": pattern["intra_halo_size"], + "inter_halo_size": pattern["inter_halo_size"], + "c_intra_bytes": pattern["intra_halo_size"] * F * 4, + "c_inter_bytes": pattern["inter_halo_size"] 
* F * 4, + "send_total": sum(send_counts), + "recv_total": sum(recv_counts), + "trials_seconds": times_local, + } + + all_stats = [None] * world_size + dist.all_gather_object(all_stats, stats_local) + + if rank == 0: + med = sorted(times_local)[len(times_local) // 2] + print( + f"[e2e] K={world_size} F={F} {args.graph}/{args.partitioner}/{args.model} " + f"n_local={n_local} n_halo={n_halo} " + f"median {1e3*med:.2f} ms" + ) + payload = { + "benchmark": "end_to_end", + "metadata": collect_metadata(), + "config": { + "graph": args.graph, + "num_vertices": args.num_vertices, + "avg_degree": args.avg_degree, + "sbm_inter_density": args.sbm_inter_density, + "feature_dim": F, + "model": args.model, + "partitioner": args.partitioner, + "world_size": world_size, + "ranks_per_node": pattern["ranks_per_node"], + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + }, + "measurements": [ + { + "params": { + "world_size": world_size, + "feature_dim": F, + "graph": args.graph, + "partitioner": args.partitioner, + "model": args.model, + }, + "rank0_trials_seconds": times_local, + "per_rank_stats": all_stats, + } + ], + } + write_result(args.output, payload) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_gather.py b/experiments/cost_model_benchmarks/benchmarks/bench_gather.py new file mode 100644 index 0000000..d62c71b --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_gather.py @@ -0,0 +1,181 @@ +"""Benchmark 1.4 — Buffer-Copy / Gather-Scatter Bandwidth. + +Single-GPU benchmark. 
Measures the effective bandwidth of ``x[idx]`` (gather) +and the corresponding ``scatter_add_`` (backward) for three index distributions: + +* ``contiguous`` — contiguous block starting at a random offset +* ``clustered`` — *c* cluster centres, each with a contiguous block of size + ``--cluster-size``; simulates a well-partitioned graph halo +* ``random`` — uniformly random indices (worst-case for cache) + +Sweeps *k* (number of gathered rows) from ``--min-k`` to ``--max-k``. + +Usage:: + + python -m benchmarks.bench_gather \\ + --distribution clustered \\ + --min-k 1000 --max-k 10000000 --steps 20 \\ + --N 20000000 --feature-dim 128 --cluster-size 64 \\ + --warmup 10 --trials 50 \\ + --output data/gather_clustered.json --seed 42 +""" + +import argparse + +import numpy as np +import torch + +from benchmarks.common import ( + collect_metadata, + cuda_timed, + seed_everything, + write_result, +) + + +# --------------------------------------------------------------------------- +# Index generators +# --------------------------------------------------------------------------- + +def contiguous_idx(N: int, k: int, device: torch.device) -> torch.Tensor: + start = torch.randint(0, max(1, N - k), (1,)).item() + return torch.arange(start, start + k, device=device) + + +def clustered_idx(N: int, k: int, cluster_size: int, + device: torch.device, rng: np.random.Generator) -> torch.Tensor: + """Draw cluster centres, then take a contiguous block around each.""" + num_clusters = max(1, k // cluster_size) + centres = rng.integers(0, N, size=num_clusters) + idx_parts = [] + for c in centres: + start = int(np.clip(c, 0, N - cluster_size)) + idx_parts.append(torch.arange(start, start + cluster_size, device=device)) + idx = torch.cat(idx_parts)[:k] + return idx + + +def random_idx(N: int, k: int, device: torch.device) -> torch.Tensor: + return torch.randperm(N, device=device)[:k] + + +# --------------------------------------------------------------------------- +# Main +# 
--------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="Gather/scatter bandwidth benchmark") + p.add_argument("--distribution", choices=["contiguous", "clustered", "random"], + required=True) + p.add_argument("--min-k", type=int, default=1_000) + p.add_argument("--max-k", type=int, default=10_000_000) + p.add_argument("--steps", type=int, default=20) + p.add_argument("--N", type=int, default=20_000_000, + help="Total number of rows in the source tensor x") + p.add_argument("--feature-dim", type=int, default=128) + p.add_argument("--cluster-size", type=int, default=64, + help="Rows per cluster (only used with --distribution clustered)") + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--trials", type=int, default=50) + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + seed_everything(args.seed) + rng = np.random.default_rng(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + F = args.feature_dim + + # Pre-allocate the source tensor and gradient buffer once + x = torch.randn(args.N, F, device=device) + grad_y_template = torch.ones(1, F, device=device) # resized per k + + # Sweep k values + k_values = np.unique( + np.round( + np.logspace( + np.log10(args.min_k), + np.log10(args.max_k), + num=args.steps, + ) + ).astype(int) + ).tolist() + k_values = [min(int(k), args.N) for k in k_values] + + measurements = [] + for k in k_values: + # Build index tensor + if args.distribution == "contiguous": + idx = contiguous_idx(args.N, k, device) + elif args.distribution == "clustered": + idx = clustered_idx(args.N, k, args.cluster_size, device, rng) + else: + idx = random_idx(args.N, k, device) + + # Expand idx for scatter_add_: shape [k, F] + idx_expanded = idx.unsqueeze(1).expand(-1, F) + grad_y = torch.ones(k, F, device=device) 
+ grad_x = torch.zeros_like(x) + + # --- Forward gather --- + def gather_fn(): + _ = x[idx] + + gather_times = cuda_timed(gather_fn, warmup=args.warmup, trials=args.trials) + + # --- Backward scatter-add --- + def scatter_fn(): + grad_x.zero_() + grad_x.scatter_add_(0, idx_expanded, grad_y) + + scatter_times = cuda_timed(scatter_fn, warmup=args.warmup, trials=args.trials) + + measurements.append({ + "params": { + "k": k, + "N": args.N, + "feature_dim": F, + "distribution": args.distribution, + "cluster_size": args.cluster_size if args.distribution == "clustered" else None, + }, + "gather_trials_seconds": gather_times, + "scatter_add_trials_seconds": scatter_times, + }) + + med_g = sorted(gather_times)[len(gather_times) // 2] + med_s = sorted(scatter_times)[len(scatter_times) // 2] + bytes_moved = k * F * 4 + bw_g = bytes_moved / med_g / 1e9 + bw_s = bytes_moved / med_s / 1e9 + print( + f"[gather/{args.distribution}] k={k:>9} " + f"gather {1e3*med_g:.2f} ms ({bw_g:.1f} GB/s) " + f"scatter {1e3*med_s:.2f} ms ({bw_s:.1f} GB/s)" + ) + + payload = { + "benchmark": "gather", + "metadata": collect_metadata(), + "config": { + "distribution": args.distribution, + "min_k": args.min_k, + "max_k": args.max_k, + "steps": args.steps, + "N": args.N, + "feature_dim": F, + "cluster_size": args.cluster_size, + "warmup": args.warmup, + "trials": args.trials, + "seed": args.seed, + }, + "measurements": measurements, + } + write_result(args.output, payload) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py b/experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py new file mode 100644 index 0000000..c844311 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/bench_pingpong.py @@ -0,0 +1,161 @@ +"""Benchmark 1.1 — Network Bandwidth / Latency (Ping-Pong). + +Measures one-way transfer time across a sweep of message sizes using a +two-rank ping-pong pattern. 
Run once with both ranks on the same node +(intra-node, NVLink) and once with ranks on different nodes (inter-node, +InfiniBand). The SLURM script controls placement; this script only records +a --mode label. + +Usage (via torchrun / srun):: + + torchrun --nnodes 1 --nproc_per_node 2 \\ + -m benchmarks.bench_pingpong \\ + --mode intra --min-bytes 64 --max-bytes 67108864 --steps 21 \\ + --warmup 20 --trials 100 --output data/pingpong_intra.json --seed 42 +""" + +import argparse +import os + +import torch +import torch.distributed as dist + +from benchmarks.common import ( + collect_metadata, + seed_everything, + setup_distributed, + write_result, +) + + +# --------------------------------------------------------------------------- +# Ping-pong timing +# --------------------------------------------------------------------------- + +def pingpong_timed(rank: int, tensor: torch.Tensor, warmup: int, trials: int) -> list: + """Perform a ping-pong between rank 0 and rank 1. + + Returns per-trial *one-way* transfer times in seconds (rank 0 only). + Rank 1 returns an empty list. 
+ """ + # Warmup + for _ in range(warmup): + if rank == 0: + dist.send(tensor, dst=1) + dist.recv(tensor, src=1) + else: + dist.recv(tensor, src=0) + dist.send(tensor, dst=0) + torch.cuda.synchronize() + dist.barrier() + + times = [] + for _ in range(trials): + dist.barrier() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + if rank == 0: + start_evt.record() + dist.send(tensor, dst=1) + dist.recv(tensor, src=1) + end_evt.record() + torch.cuda.synchronize() + # Round-trip / 2 = one-way + times.append(start_evt.elapsed_time(end_evt) / 2.0 / 1_000.0) + else: + dist.recv(tensor, src=0) + dist.send(tensor, dst=0) + + return times + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def parse_args(): + p = argparse.ArgumentParser(description="Ping-pong bandwidth/latency benchmark") + p.add_argument("--min-bytes", type=int, default=64) + p.add_argument("--max-bytes", type=int, default=67_108_864) # 64 MiB + p.add_argument("--steps", type=int, default=21, + help="Number of logarithmically-spaced message sizes") + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--trials", type=int, default=100) + p.add_argument("--mode", choices=["intra", "inter"], default="inter", + help="Label only — actual placement is controlled by SLURM") + p.add_argument("--output", type=str, required=True) + p.add_argument("--seed", type=int, default=42) + return p.parse_args() + + +def main(): + args = parse_args() + rank, world_size, local_rank = setup_distributed() + + if world_size != 2: + raise ValueError(f"bench_pingpong requires exactly 2 ranks, got {world_size}") + + seed_everything(args.seed) + device = torch.device(f"cuda:{local_rank}") + + # Build logarithmically-spaced byte sizes (powers-of-2 friendly) + import numpy as np + byte_sizes = np.unique( + np.round( + np.logspace( + 
np.log2(args.min_bytes), + np.log2(args.max_bytes), + num=args.steps, + base=2, + ) + ).astype(int) + ).tolist() + + measurements = [] + for nbytes in byte_sizes: + # Float32 elements + num_elems = max(1, nbytes // 4) + tensor = torch.zeros(num_elems, dtype=torch.float32, device=device) + + times = pingpong_timed(rank, tensor, args.warmup, args.trials) + + if rank == 0: + measurements.append({ + "params": { + "message_bytes": num_elems * 4, + "num_elements": num_elems, + "mode": args.mode, + }, + "trials_seconds": times, + }) + print( + f"[pingpong] {num_elems * 4:>10} bytes | " + f"median {1e3 * float(sorted(times)[len(times)//2]):.3f} ms" + ) + + dist.barrier() + + if rank == 0: + payload = { + "benchmark": "pingpong", + "metadata": collect_metadata(), + "config": { + "min_bytes": args.min_bytes, + "max_bytes": args.max_bytes, + "steps": args.steps, + "warmup": args.warmup, + "trials": args.trials, + "mode": args.mode, + "world_size": world_size, + "seed": args.seed, + }, + "measurements": measurements, + } + write_result(args.output, payload) + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/benchmarks/common.py b/experiments/cost_model_benchmarks/benchmarks/common.py new file mode 100644 index 0000000..27e8069 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/common.py @@ -0,0 +1,176 @@ +"""Shared utilities for cost-model benchmarks: timing, logging, metadata.""" + +import json +import os +import random +import socket +import subprocess +import time +from pathlib import Path +from typing import Callable + +import numpy as np +import torch +import torch.distributed as dist + + +# --------------------------------------------------------------------------- +# Timing +# --------------------------------------------------------------------------- + + +def cuda_timed(fn: Callable, warmup: int = 10, trials: int = 50) -> list: + """Run *fn* with CUDA-event timing. 
Returns per-trial wall times in seconds. + + The function is invoked with no arguments. Callers should capture any + needed state via closure. Warmup iterations are discarded. + """ + # Warmup + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + times = [] + for _ in range(trials): + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + fn() + end_evt.record() + torch.cuda.synchronize() + # elapsed_time returns milliseconds + times.append(start_evt.elapsed_time(end_evt) / 1_000.0) + return times + + +# --------------------------------------------------------------------------- +# Metadata collection +# --------------------------------------------------------------------------- + + +def collect_metadata() -> dict: + """Return a dict of reproducibility metadata for the current run.""" + meta: dict = { + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), + "hostname": socket.gethostname(), + } + + # GPU info + if torch.cuda.is_available(): + gpus = [] + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + gpu_entry = { + "index": i, + "name": props.name, + "compute_capability": f"{props.major}.{props.minor}", + "total_memory_bytes": props.total_memory, + } + # UUID available in newer PyTorch builds + if hasattr(props, "uuid"): + gpu_entry["uuid"] = str(props.uuid) + gpus.append(gpu_entry) + meta["gpus"] = gpus + meta["cuda_version"] = torch.version.cuda + else: + meta["gpus"] = [] + meta["cuda_version"] = None + + meta["pytorch_version"] = torch.__version__ + + # NCCL version (tuple -> string) + try: + nccl_ver = torch.cuda.nccl.version() + meta["nccl_version"] = ".".join(str(x) for x in nccl_ver) + except Exception: + meta["nccl_version"] = "unknown" + + # SLURM environment variables + slurm_keys = [ + "SLURM_JOB_ID", + "SLURM_NODELIST", + "SLURM_NNODES", + "SLURM_NTASKS", + "SLURM_PROCID", + "SLURM_LOCALID", + "SLURM_ARRAY_JOB_ID", + 
"SLURM_ARRAY_TASK_ID", + ] + meta["slurm"] = {k: os.environ.get(k) for k in slurm_keys} + + # Git commit hash of the benchmark code + try: + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + meta["git_commit"] = result.stdout.strip() + except Exception: + meta["git_commit"] = "unknown" + + return meta + + +# --------------------------------------------------------------------------- +# JSON output +# --------------------------------------------------------------------------- + + +def write_result(path: str, payload: dict) -> None: + """Write *payload* as a JSON file at *path*, creating parents as needed. + + Expected schema:: + + { + "benchmark": "", + "metadata": { ... }, + "config": { ... }, + "measurements": [ + {"params": {...}, "trials_seconds": [t1, t2, ...]}, + ... + ] + } + """ + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + with open(p, "w") as fh: + json.dump(payload, fh, indent=2) + print(f"[write_result] Saved {p} ({p.stat().st_size} bytes)") + + +# --------------------------------------------------------------------------- +# Distributed setup +# --------------------------------------------------------------------------- + + +def setup_distributed() -> tuple[int, int, int]: + """Initialize torch.distributed with the env-var init method (NCCL). + + Expects MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and LOCAL_RANK to be + set in the environment (standard for torchrun / SLURM + srun). 
+
+    Returns:
+        (rank, world_size, local_rank)
+    """
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl", init_method="env://")
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    torch.cuda.set_device(local_rank)
+    return rank, world_size, local_rank
+
+
+# ---------------------------------------------------------------------------
+# Seeding
+# ---------------------------------------------------------------------------
+
+
+def seed_everything(seed: int) -> None:
+    """Set Python, NumPy, and PyTorch (CPU + CUDA) random seeds."""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
diff --git a/experiments/cost_model_benchmarks/benchmarks/graph_data_common.py b/experiments/cost_model_benchmarks/benchmarks/graph_data_common.py
new file mode 100644
index 0000000..15358e5
--- /dev/null
+++ b/experiments/cost_model_benchmarks/benchmarks/graph_data_common.py
@@ -0,0 +1,213 @@
+import os
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
+# ===========================================================================
+# Synthetic graph generators
+# ===========================================================================
+
+
+def gen_erdos_renyi(
+    num_vertices: int, avg_degree: float, rng: np.random.Generator
+) -> np.ndarray:
+    """Return edge array of shape [E, 2] (src, dst) for an Erdős-Rényi digraph."""
+    num_edges = int(num_vertices * avg_degree)
+    src = rng.integers(0, num_vertices, size=num_edges)
+    dst = rng.integers(0, num_vertices, size=num_edges)
+    return np.stack([src, dst], axis=1)
+
+
+def gen_sbm(
+    num_vertices: int, avg_degree: float, inter_density: float, rng: np.random.Generator
+) -> np.ndarray:
+    """Return edges for a Stochastic Block Model graph.
+
+    Vertices are split into blocks of equal size (one per rank for convenience,
+    though the actual partitioning is a separate step). 
The ratio of + intra-block to inter-block edges is controlled by *inter_density*. + """ + world_size = dist.get_world_size() if dist.is_initialized() else 4 + block_size = num_vertices // world_size + edges = [] + target_edges = int(num_vertices * avg_degree) + + intra_edges = int(target_edges * (1.0 - inter_density)) + inter_edges = target_edges - intra_edges + + # Intra-block edges + for b in range(world_size): + start = b * block_size + end = start + block_size + n = intra_edges // world_size + s = rng.integers(start, end, size=n) + d = rng.integers(start, end, size=n) + edges.append(np.stack([s, d], axis=1)) + + # Inter-block edges + s = rng.integers(0, num_vertices, size=inter_edges) + d = rng.integers(0, num_vertices, size=inter_edges) + # Force cross-block by offsetting dst block + d_block = ( + s // block_size + 1 + rng.integers(0, world_size - 1, size=inter_edges) + ) % world_size + d = d_block * block_size + rng.integers(0, block_size, size=inter_edges) + d = np.clip(d, 0, num_vertices - 1) + edges.append(np.stack([s, d], axis=1)) + + return np.concatenate(edges, axis=0) + + +# =========================================================================== +# Partitioners +# =========================================================================== + + +def partition_random( + num_vertices: int, world_size: int, rng: np.random.Generator +) -> np.ndarray: + return rng.integers(0, world_size, size=num_vertices).astype(np.int64) + + +def partition_balanced(num_vertices: int, world_size: int) -> np.ndarray: + return np.floor(np.arange(num_vertices) * world_size / num_vertices).astype( + np.int64 + ) + + +def partition_metis( + num_vertices: int, world_size: int, edges: np.ndarray +) -> np.ndarray: + try: + import pymetis + except ImportError: + raise RuntimeError( + "pymetis is not installed. Install it with: pip install pymetis\n" + "Or use --partitioner random or --partitioner balanced." 
+ ) + # Build adjacency list for pymetis + adj = [[] for _ in range(num_vertices)] + for s, d in edges: + adj[s].append(int(d)) + adj[d].append(int(s)) + _, membership = pymetis.part_graph(world_size, adjacency=adj) + return np.array(membership, dtype=np.int64) + + +# =========================================================================== +# Minimal halo-exchange infrastructure +# =========================================================================== + + +def build_local_comm_pattern( + edges: np.ndarray, assignment: np.ndarray, rank: int, world_size: int +): + """Compute the local communication pattern for this rank. + + Returns a CommunicationPattern object with: + local_vertices — np.ndarray of vertex IDs owned by this rank + local_edge_index — torch.Tensor [2, E_local] with local vertex IDs + remapped so that 0..n_local-1 are owned vertices + and n_local..n_local+n_halo-1 are halo vertices + send_counts — list[int] of length world_size: vertices to send + recv_counts — list[int] of length world_size: vertices to recv + send_idx — local indices (into local_vertices) to send per rank + halo_global_ids — global vertex IDs of halo vertices, in recv order + intra_halo_size — halo vertices from same node (ranks sharing node) + inter_halo_size — halo vertices from remote nodes + ranks_per_node — int (derived from LOCAL_RANK / RANK relationship) + """ + local_mask = assignment == rank + local_vertices = np.where(local_mask)[0] + n_local = len(local_vertices) + + # Global -> local index map + g2l = {int(v): i for i, v in enumerate(local_vertices)} + + # Find edges where dst is local + local_dst_mask = np.isin(edges[:, 1], local_vertices) + local_edges = edges[local_dst_mask] + + # Halo: src vertices not owned by this rank + halo_src_mask = ~np.isin(local_edges[:, 0], local_vertices) + halo_global = np.unique(local_edges[halo_src_mask, 0]) + + # Group halo vertices by owning rank + halo_owners = assignment[halo_global] + recv_by_rank = [] + halo_order = [] + for 
r in range(world_size): + verts = halo_global[halo_owners == r] + recv_by_rank.append(verts) + halo_order.extend(verts.tolist()) + halo_order = np.array(halo_order, dtype=np.int64) + + # Global halo id -> local halo index + halo_g2l = {int(v): n_local + i for i, v in enumerate(halo_order)} + all_g2l = {**g2l, **halo_g2l} + + # Find which local vertices other ranks need (send pattern) + # We exchange recv_counts via all_to_all to learn send_counts + recv_counts = [len(rv) for rv in recv_by_rank] + + # Build send: for each rank r, which of our local vertices does r need? + # We do a global exchange of halo_global per rank + all_recv = [None] * world_size + dist.all_gather_object(all_recv, halo_order.tolist()) + + send_idx_by_rank = [] + for r in range(world_size): + needed = np.array(all_recv[r], dtype=np.int64) + owned_mask = ( + assignment[needed] == rank if len(needed) > 0 else np.array([], dtype=bool) + ) + owned = needed[owned_mask] if len(needed) > 0 else np.array([], dtype=np.int64) + # Map to local indices + local_idxs = np.array([g2l[int(v)] for v in owned], dtype=np.int64) + send_idx_by_rank.append(local_idxs) + + send_counts = [len(s) for s in send_idx_by_rank] + + # Remap edges to local indices + valid_edge_mask = np.array( + [(int(s) in all_g2l) and (int(d) in all_g2l) for s, d in local_edges] + ) + local_edges_valid = local_edges[valid_edge_mask] + if len(local_edges_valid) > 0: + remapped_src = np.array([all_g2l[int(s)] for s in local_edges_valid[:, 0]]) + remapped_dst = np.array([all_g2l[int(d)] for d in local_edges_valid[:, 1]]) + edge_index = torch.tensor( + np.stack([remapped_src, remapped_dst], axis=0), dtype=torch.long + ) + else: + edge_index = torch.zeros((2, 0), dtype=torch.long) + + # Compute intra / inter halo sizes + ranks_per_node = int( + os.environ.get("LOCAL_WORLD_SIZE", os.environ.get("SLURM_NTASKS_PER_NODE", "4")) + ) + my_node = rank // ranks_per_node + intra_halo_size = 0 + inter_halo_size = 0 + for r, verts in 
enumerate(recv_by_rank): + peer_node = r // ranks_per_node + if peer_node == my_node: + intra_halo_size += len(verts) + else: + inter_halo_size += len(verts) + + return { + "local_vertices": local_vertices, + "n_local": n_local, + "n_halo": len(halo_order), + "edge_index": edge_index, + "send_counts": send_counts, + "recv_counts": recv_counts, + "send_idx_by_rank": send_idx_by_rank, + "halo_order": halo_order, + "intra_halo_size": intra_halo_size, + "inter_halo_size": inter_halo_size, + "ranks_per_node": ranks_per_node, + } diff --git a/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py b/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py new file mode 100644 index 0000000..63a6af3 --- /dev/null +++ b/experiments/cost_model_benchmarks/benchmarks/nn_layer_common.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torch.distributed as dist + +# =========================================================================== +# GNN layers (same as bench_compute.py for consistency) +# =========================================================================== + + +class GCNLayer(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.linear = nn.Linear(feature_dim, feature_dim, bias=False) + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: + src, dst = edge_index[0], edge_index[1] + n_local = x.shape[0] + # Only update local vertices (dst < n_local guard not needed since + # edge_index already restricts to local dst) + msg = self.linear(x[src]) + out = torch.zeros_like(x) + out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) + return out + + +class EdgeConditionedLayer(nn.Module): + def __init__(self, feature_dim: int): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(3 * feature_dim, feature_dim), + nn.ReLU(), + nn.Linear(feature_dim, feature_dim), + ) + + def forward( + self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor + ) -> torch.Tensor: 
+ src, dst = edge_index[0], edge_index[1] + msg = self.mlp(torch.cat([x[src], x[dst], edge_attr], dim=-1)) + out = torch.zeros_like(x) + out.scatter_add_(0, dst.unsqueeze(1).expand_as(msg), msg) + return out diff --git a/experiments/cost_model_benchmarks/run_local_compute_tests.sh b/experiments/cost_model_benchmarks/run_local_compute_tests.sh new file mode 100644 index 0000000..726bfe0 --- /dev/null +++ b/experiments/cost_model_benchmarks/run_local_compute_tests.sh @@ -0,0 +1,12 @@ +python -m benchmarks.bench_compute --model edge --sweep edges --min 1000 --max 1000000 --steps 10 \ + --fixed-value 200000 --feature-dim 512 --warmup 5 --trials 20 --output data/compute_edge_eswp_test.json --seed 4 + +python -m benchmarks.bench_compute --model gcn --sweep edges --min 1000 --max 1000000 --steps 10 \ + --fixed-value 200000 --feature-dim 512 --warmup 5 --trials 20 --output data/compute_gcn_eswp_test.json --seed 4 + +python -m analysis.fit_primitives --compute-gcn data/compute_gcn_*swp_test.json \ + --compute-edge data/compute_edge_*swp_test.json --output data/fitted_primitives.json + +python -m visualization.plot_compute --gcn-vertex data/compute_gcn_vswp_test.json \ + --gcn-edge data/compute_gcn_eswp_test.json --edge-vertex data/compute_edge_vswp_test.json \ + --edge-edge data/compute_edge_eswp_test.json --primitives data/fitted_primitives.json --output figures/compute diff --git a/experiments/cost_model_benchmarks/visualization/__init__.py b/experiments/cost_model_benchmarks/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/cost_model_benchmarks/visualization/plot_ablations.py b/experiments/cost_model_benchmarks/visualization/plot_ablations.py new file mode 100644 index 0000000..a23eb4f --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_ablations.py @@ -0,0 +1,160 @@ +"""Visualization — Ablation Studies. + +Two-panel figure: + (a) Hierarchical model (intra+inter separate) vs. 
flat model (single B, t_L) + on a topology sweep (varying SBM inter-block density). + (b) Full model vs. model without the T_buffer_copy term. + +Reads ``data/predictions.json`` (which has the full breakdown per entry). + +Usage:: + + python -m visualization.plot_ablations \\ + --predictions data/predictions.json \\ + --output figures/ablations +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8, + "xtick.labelsize": 8, "ytick.labelsize": 8, + "figure.dpi": 300, "text.usetex": False, +}) + +COLORS = { + "full": "#1f77b4", + "flat": "#ff7f0e", + "no_buffer": "#d62728", + "measured": "#2ca02c", +} + + +def relative_error(pred, meas): + return abs(pred - meas) / meas if meas > 0 else float("nan") + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot ablation studies") + p.add_argument("--predictions", type=str, required=True) + p.add_argument("--output", type=str, default="figures/ablations") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.predictions) as f: + data = json.load(f) + + entries = data["predictions"] + T_overhead = data.get("T_overhead_seconds", 0.0) + + # --- Panel (a): topology sweep (SBM inter-density) --- + # Filter to SBM entries, group by inter_density + sbm_entries = [e for e in entries + if e["config"].get("graph", "") == "sbm"] + + density_groups = {} + for e in sbm_entries: + d = e["config"].get("sbm_inter_density", 0.0) + density_groups.setdefault(d, []).append(e) + + densities = sorted(density_groups.keys()) + mape_full = [] + mape_flat = [] + + for d in densities: + grp = density_groups[d] + full_errs, flat_errs = [], [] + for e in grp: + T_meas = e["measured_median_seconds"] + # Full hierarchical prediction is already in predictions.json + T_pred_full = e["predicted_seconds"] + 
full_errs.append(relative_error(T_pred_full, T_meas)) + + # Flat model: use a single network term = T_intra + T_inter (not max) + # Approximate: flat model can't overlap, so T_comm = T_intra + T_inter + bd = e.get("breakdown", {}) + T_comm_hier = bd.get("T_comm_seconds", 0.0) + # Flat approximation: assume both intra and inter are sequential + c_intra = e["partition_stats"].get("c_intra_bytes", 0) + c_inter = e["partition_stats"].get("c_inter_bytes", 0) + # Without knowing the individual bandwidths, use ratio heuristic: + # flat ≈ 2 * max (conservative estimate) + T_comm_flat = T_comm_hier * 2.0 + T_pred_flat = (T_pred_full - T_comm_hier + T_comm_flat) + flat_errs.append(relative_error(T_pred_flat, T_meas)) + + mape_full.append(np.mean(full_errs) * 100 if full_errs else float("nan")) + mape_flat.append(np.mean(flat_errs) * 100 if flat_errs else float("nan")) + + # --- Panel (b): with vs. without T_buffer_copy --- + meas_all = np.array([e["measured_median_seconds"] * 1e3 for e in entries]) + pred_full = np.array([e["predicted_seconds"] * 1e3 for e in entries]) + pred_nobuf = np.array([ + (e["predicted_seconds"] - e.get("breakdown", {}).get("T_buffer_copy_seconds", 0.0)) * 1e3 + for e in entries + ]) + + fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(9, 3.8)) + + # --- Panel (a) --- + if densities: + ax_a.plot(densities, mape_full, "o-", color=COLORS["full"], + markersize=5, linewidth=1.2, label="Hierarchical (intra+inter)") + ax_a.plot(densities, mape_flat, "s--", color=COLORS["flat"], + markersize=5, linewidth=1.2, label="Flat (single-tier)") + ax_a.set_xlabel("SBM inter-block edge density") + ax_a.set_ylabel("MAPE (%)") + ax_a.set_title("(a) Hierarchical vs. Flat Model\nover Topology Sweep", fontsize=9) + ax_a.legend() + ax_a.grid(True, linestyle=":", linewidth=0.4) + else: + ax_a.text(0.5, 0.5, "No SBM data found.\nRun bench_end_to_end with --graph sbm.", + ha="center", va="center", transform=ax_a.transAxes, fontsize=8) + ax_a.set_title("(a) Hierarchical vs. 
Flat Model", fontsize=9) + + # --- Panel (b) --- + lo = min(meas_all.min(), pred_full.min(), pred_nobuf.min()) * 0.9 + hi = max(meas_all.max(), pred_full.max(), pred_nobuf.max()) * 1.1 + ax_b.plot([lo, hi], [lo, hi], "k--", linewidth=0.8, label="Ideal") + ax_b.scatter(meas_all, pred_full, s=18, color=COLORS["full"], + alpha=0.8, label="Full model", zorder=3) + ax_b.scatter(meas_all, pred_nobuf, s=18, color=COLORS["no_buffer"], + alpha=0.6, marker="^", label="Without $T_{\\mathrm{buf}}$", zorder=3) + ax_b.set_xlim(lo, hi) + ax_b.set_ylim(lo, hi) + ax_b.set_xlabel("Measured $T_{\\mathrm{layer}}$ (ms)") + ax_b.set_ylabel("Predicted $T_{\\mathrm{layer}}$ (ms)") + ax_b.set_title("(b) Ablation: With vs. Without\n$T_{\\mathrm{buffer-copy}}$ Term", fontsize=9) + ax_b.legend() + ax_b.grid(True, linestyle=":", linewidth=0.4) + + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_ablations] Saved {out}.pdf and {out}.png") + print( + "Caption: Ablation studies. " + "(a) MAPE of the hierarchical cost model (separate intra/inter tiers) vs. " + "a flat model (single bandwidth parameter) across the SBM topology sweep. " + "The hierarchical model degrades more gracefully as inter-block density increases. " + "(b) Predicted vs. measured scatter with (blue circles) and without (red triangles) " + "the $T_{\\mathrm{buffer-copy}}$ term, demonstrating that omitting this term " + "causes systematic under-prediction for large halo sizes." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_compute.py b/experiments/cost_model_benchmarks/visualization/plot_compute.py new file mode 100644 index 0000000..20ef9ce --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_compute.py @@ -0,0 +1,172 @@ +"""Visualization — GNN Compute Primitive Runtime. + +Two-panel figure: GCN-like (left) vs. edge-conditioned (right). +Each panel shows forward runtime vs. the swept variable (vertices or edges) +with the fitted linear model overlaid. + +Usage:: + + python -m visualization.plot_compute \\ + --gcn-vertex data/compute_gcn_vswp.json \\ + --gcn-edge data/compute_gcn_eswp.json \\ + --edge-vertex data/compute_edge_vswp.json \\ + --edge-edge data/compute_edge_eswp.json \\ + --primitives data/fitted_primitives.json \\ + --output figures/compute +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +COLORS = {"gcn": "#2ca02c", "edge": "#ff7f0e"} +plt.rcParams.update( + { + "font.size": 9, + "axes.labelsize": 9, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "figure.dpi": 300, + "text.usetex": False, + } +) + + +def load_compute_file(path: str, timing_key: str = "forward_trials_seconds"): + """Returns lists of (sweep_value, median, q25, q75).""" + with open(path) as f: + data = json.load(f) + sweep = data["config"]["sweep"] + rows = [] + for meas in data["measurements"]: + trials = np.array(meas[timing_key]) + rows.append( + ( + meas["params"]["sweep_value"], + meas["params"]["num_vertices"], + meas["params"]["num_edges"], + float(np.median(trials)), + float(np.percentile(trials, 25)), + float(np.percentile(trials, 75)), + ) + ) + rows.sort(key=lambda r: r[0]) + return sweep, rows + + +def fitted_compute(sweep_vals, fixed_val, sweep, model_type, primitives): + params = primitives.get("compute", 
{}).get(model_type, {}).get("forward", None) + if params is None: + return None + a, b, c = params["coeff_V"], params["coeff_E"], params["intercept"] + if sweep == "vertices": + V_arr = np.array(sweep_vals, dtype=float) + E_arr = np.full_like(V_arr, fixed_val) + else: + E_arr = np.array(sweep_vals, dtype=float) + V_arr = np.full_like(E_arr, fixed_val) + return a * V_arr + b * E_arr + c + + +def plot_one_panel(ax, rows, sweep, fixed_val, model_type, primitives, color, title): + xvals = [r[0] for r in rows] + meds = np.array([r[3] for r in rows]) * 1e3 + lo = np.array([r[3] - r[4] for r in rows]) * 1e3 + hi = np.array([r[5] - r[3] for r in rows]) * 1e3 + + ax.errorbar( + xvals, + meds, + yerr=[lo, hi], + fmt="o", + markersize=4, + color=color, + capsize=2, + linewidth=0.8, + elinewidth=0.8, + label="Measured (IQR)", + ) + + fit = fitted_compute(xvals, fixed_val, sweep, model_type, primitives) + if fit is not None: + ax.plot(xvals, fit * 1e3, "--", color=color, linewidth=1.2, label="Fit") + + xlabel = "|V| (vertices)" if sweep == "vertices" else "|E| (edges)" + ax.set_xlabel(xlabel) + ax.set_ylabel("Forward time (ms)") + ax.set_title(title, fontsize=9) + ax.set_xscale("log") + ax.set_yscale("log") + + ax.legend() + ax.grid(True, which="both", linestyle=":", linewidth=0.4) + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot GNN compute primitive results") + p.add_argument("--gcn-vertex", type=str, default=None) + p.add_argument("--gcn-edge", type=str, default=None) + p.add_argument("--edge-vertex", type=str, default=None) + p.add_argument("--edge-edge", type=str, default=None) + p.add_argument("--primitives", type=str, default=None) + p.add_argument("--output", type=str, default="figures/compute") + return p.parse_args() + + +def main(): + args = parse_args() + primitives = {} + if args.primitives: + with open(args.primitives) as f: + primitives = json.load(f) + + fig, axes = plt.subplots(1, 2, figsize=(7, 3)) + + panel_map = [ + ("gcn", "edges", 
args.gcn_edge, axes[0], "GCN-like"), + ("edge", "edges", args.edge_edge, axes[1], "Edge-conditioned"), + ] + + for model_type, sweep_label, path, ax, title in panel_map: + if path is None: + ax.set_visible(False) + continue + sweep, rows = load_compute_file(path) + fixed_val = rows[0][2] if sweep == "vertices" else rows[0][1] # E or V fixed + plot_one_panel( + ax, + rows, + sweep, + fixed_val, + model_type, + primitives, + COLORS[model_type], + title, + ) + + # fig.suptitle("GNN Compute Primitive: Forward Runtime vs. Graph Size", fontsize=10) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_compute] Saved {out}.pdf and {out}.png") + print( + "Caption: Forward runtime of a single GNN layer vs. subgraph size " + "(vertex sweep top row, edge sweep bottom row) for GCN-like (left) " + "and edge-conditioned (right) message functions. " + "Dashed lines: fitted model $T_{\\mathrm{comp}} = a|V| + b|E| + c$. " + "Error bars span IQR over 50+ trials." 
+ ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_crossover.py b/experiments/cost_model_benchmarks/visualization/plot_crossover.py new file mode 100644 index 0000000..cfa02cb --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_crossover.py @@ -0,0 +1,183 @@ +import numpy as np +import matplotlib.pyplot as plt +import json +from pathlib import Path +import argparse + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot Crossover Benchmark Results") + p.add_argument( + "--input", type=str, required=True, help="Path to benchmark JSON file" + ) + p.add_argument( + "--output", + type=str, + default="crossover_analysis.png", + help="Path to save the generated plot", + ) + return p.parse_args() + + +def plot_crossover_benchmark(payload, save_path="crossover_analysis.png"): + """ + Parses benchmark payload and plots Single GPU vs Multi-GPU execution times, + annotating the crossover point where distributed training becomes faster. 
+ """ + # Sort measurements strictly by graph size (num_vertices) + measurements = sorted( + payload["measurements"], key=lambda x: x["params"]["num_vertices"] + ) + + vertices = [] + single_gpu_times = [] + multi_gpu_times = [] + single_gpu_oom_points = [] + + world_size = payload["config"]["world_size"] + model_name = payload["config"]["model"] + partitioner = payload["config"]["partitioner"] + feature_dim = payload["config"]["feature_dim"] + + for m in measurements: + v = m["params"]["num_global_edges"] + vertices.append(v) + + # Multi-GPU: Use the max time across ranks as it represents the true synchronous bottleneck + multi_time = np.median(m["multi_gpu_trials_seconds_max"]) + multi_gpu_times.append(multi_time) + + # Single-GPU: Handle OOM scenarios safely + if m.get("single_gpu_oom", False) or not m["single_gpu_trials_seconds"]: + single_gpu_times.append(np.nan) + single_gpu_oom_points.append(v) + else: + single_gpu_times.append(np.median(m["single_gpu_trials_seconds"])) + + vertices = np.array(vertices) + single_gpu_times = np.array(single_gpu_times) + multi_gpu_times = np.array(multi_gpu_times) + + # Initialize Plot + plt.figure(figsize=(10, 6), dpi=150) + plt.grid(True, which="both", ls="-", alpha=0.2) + + # Plot Valid Single GPU points + valid_mask = ~np.isnan(single_gpu_times) + plt.plot( + vertices[valid_mask], + single_gpu_times[valid_mask], + marker="o", + linestyle="-", + linewidth=2, + color="#1f77b4", + label="Single GPU", + ) + + # Plot Multi GPU points + plt.plot( + vertices, + multi_gpu_times, + marker="s", + linestyle="-", + linewidth=2, + color="#ff7f0e", + label=f"Distributed ({world_size} GPUs)", + ) + + # Annotate OOM boundaries + for oom_v in single_gpu_oom_points: + plt.axvline(x=oom_v, color="#d62728", linestyle="--", alpha=0.6) + plt.text( + oom_v * 1.02, + plt.ylim()[1] * 0.9, + "Single GPU OOM", + color="#d62728", + verticalalignment="top", + ) + + # Calculate and Annotate Crossover Point + crossover_found = False + for i in range(1, 
len(vertices)): + if valid_mask[i - 1] and valid_mask[i]: + diff_prev = multi_gpu_times[i - 1] - single_gpu_times[i - 1] + diff_curr = multi_gpu_times[i] - single_gpu_times[i] + + # A sign change indicates the lines crossed + if diff_prev * diff_curr < 0: + x1, x2 = vertices[i - 1], vertices[i] + y1_s, y2_s = single_gpu_times[i - 1], single_gpu_times[i] + y1_m, y2_m = multi_gpu_times[i - 1], multi_gpu_times[i] + + # Linear interpolation for precise intersection coordinates + m_s = (y2_s - y1_s) / (x2 - x1) + m_m = (y2_m - y1_m) / (x2 - x1) + + if m_s != m_m: + x_cross = x1 + (y1_m - y1_s) / (m_s - m_m) + y_cross = y1_s + m_s * (x_cross - x1) + + plt.plot( + x_cross, + y_cross, + marker="*", + color="#2ca02c", + markersize=15, + zorder=5, + ) + plt.annotate( + f"Crossover:\n~{int(x_cross):,} edges", + xy=(x_cross, y_cross), + xytext=(-20, 40), + textcoords="offset points", + fontsize=10, + fontweight="bold", + color="#2ca02c", + arrowprops=dict( + arrowstyle="->", + connectionstyle="arc3,rad=.2", + color="#2ca02c", + ), + ) + crossover_found = True + break + + # Formatting + plt.title( + f"GNN Distributed Scaling Crossover\nModel: {model_name} | Partitioner: {partitioner} | Feature Dim: {feature_dim}", + fontsize=14, + pad=15, + ) + plt.xlabel("Graph Size (Number of Edges)", fontsize=12) + plt.ylabel("Execution Time (Seconds)", fontsize=12) + plt.xscale( + "log" + ) # Using log scale for x-axis as graph sizes usually scale exponentially + plt.yscale("linear") + + plt.legend(loc="upper left", framealpha=0.9) + plt.tight_layout() + + # Output + plt.savefig(save_path) + print(f"Visualization saved to {save_path}") + if crossover_found: + print(f"Crossover point detected at approximately {int(x_cross):,} vertices.") + else: + print("No crossover point detected in the provided dataset.") + + +# Example usage assuming `payload` is already loaded in your environment: +# plot_crossover_benchmark(payload) + + +def main(): + args = parse_args() + with open(args.input, "r") as 
f: + payload = json.load(f) + + plot_crossover_benchmark(payload, save_path=args.output) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py b/experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py new file mode 100644 index 0000000..b2ea789 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_flop_crossover.py @@ -0,0 +1,140 @@ +import glob +import json +import re +import numpy as np +import matplotlib.pyplot as plt +from collections import defaultdict + + +def calculate_crossover(measurements): + """ + Calculates the exact crossover point (num_global_edges) where multi-GPU + execution becomes faster than single-GPU execution using linear interpolation. + """ + # Sort measurements strictly by graph size + measurements = sorted(measurements, key=lambda x: x["params"]["num_global_edges"]) + + vertices = [] + single_times = [] + multi_times = [] + + for m in measurements: + v = m["params"]["num_global_edges"] + + # Extract median times, ignore OOM or missing single GPU data + if m.get("single_gpu_oom", False) or not m.get("single_gpu_trials_seconds"): + continue + + s_time = np.median(m["single_gpu_trials_seconds"]) + m_time = np.median( + m["multi_gpu_trials_seconds_max"] + ) # Use max time for synchronous bottleneck + + vertices.append(v) + single_times.append(s_time) + multi_times.append(m_time) + + for i in range(1, len(vertices)): + diff_prev = multi_times[i - 1] - single_times[i - 1] + diff_curr = multi_times[i] - single_times[i] + + # Sign change indicates the lines crossed + if diff_prev * diff_curr < 0: + x1, x2 = vertices[i - 1], vertices[i] + y1_s, y2_s = single_times[i - 1], single_times[i] + y1_m, y2_m = multi_times[i - 1], multi_times[i] + + m_s = (y2_s - y1_s) / (x2 - x1) + m_m = (y2_m - y1_m) / (x2 - x1) + + if m_s != m_m: + x_cross = x1 + (y1_m - y1_s) / (m_s - m_m) + return x_cross + + return None # No crossover found in this dataset + + +def 
plot_crossover_dynamics( + file_pattern="results/crossover_*_world_F*.json", + save_path="results/crossover_vs_features.png", +): + """ + Parses benchmark files and plots the crossover point as a function of feature dimension, + with separate lines for each distributed world size. + """ + # Structure: data[world_size][feature_dim] = crossover_point + plot_data = defaultdict(dict) + + # Locate files + files = glob.glob(file_pattern) + if not files: + print(f"No files found matching pattern: {file_pattern}") + return + + # Regex to extract parameters from filename + pattern = re.compile(r"crossover_(\d+)_world_F(\d+)\.json") + + for filepath in files: + match = pattern.search(filepath) + if match: + world_size = int(match.group(1)) + feature_dim = int(match.group(2)) + + try: + with open(filepath, "r") as f: + payload = json.load(f) + + crossover_point = calculate_crossover(payload.get("measurements", [])) + if crossover_point is not None: + plot_data[world_size][feature_dim] = crossover_point + except Exception as e: + print(f"Error processing {filepath}: {e}") + + if not plot_data: + print("No valid crossover points could be calculated from the provided files.") + return + + # Initialize Plot + plt.figure(figsize=(10, 6), dpi=150) + plt.grid(True, which="both", ls="-", alpha=0.3) + + markers = ["o", "s", "^", "D", "v", "p"] + + # Plot lines per world size + for idx, (world_size, dim_data) in enumerate(sorted(plot_data.items())): + # Sort by feature dimension for sequential line plotting + sorted_dims = sorted(dim_data.keys()) + crossovers = [dim_data[d] for d in sorted_dims] + + plt.plot( + sorted_dims, + crossovers, + marker=markers[idx % len(markers)], + linestyle="-", + linewidth=2, + markersize=8, + label=f"{world_size} GPUs", + ) + + # Formatting + plt.title( + "GNN Communication Bottleneck:\nCrossover Threshold vs. 
Feature Dimension", + fontsize=14, + pad=15, + ) + plt.xlabel("Feature Dimension (F)", fontsize=12) + plt.ylabel("Crossover Point (Number of Edges)", fontsize=12) + + # Depending on the range of your feature dims, a log scale might be preferable for X + # plt.xscale("log", base=2) + plt.yscale("log") # Y-axis (vertices) usually scales exponentially + + plt.legend(title="World Size", loc="upper left", framealpha=0.9) + plt.tight_layout() + + plt.savefig(save_path) + print(f"Scaling dynamics visualization saved successfully to {save_path}") + + +if __name__ == "__main__": + plot_crossover_dynamics() diff --git a/experiments/cost_model_benchmarks/visualization/plot_gather.py b/experiments/cost_model_benchmarks/visualization/plot_gather.py new file mode 100644 index 0000000..f33b193 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_gather.py @@ -0,0 +1,221 @@ +"""Visualization — Gather / Scatter-Add Bandwidth. + +Single plot with three curves (contiguous, clustered, random) showing +gather (or scatter-add) runtime vs. k (number of rows gathered). 
+ +Usage:: + + python -m visualization.plot_gather \\ + --contiguous data/gather_contiguous_*.json \\ + --clustered data/gather_clustered_*.json \\ + --random data/gather_random_*.json \\ + --operation gather \\ + --output figures/gather +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +COLORS = { + "contiguous": "#1f77b4", + "clustered": "#2ca02c", + "random": "#d62728", +} +plt.rcParams.update( + { + "font.size": 9, + "axes.labelsize": 9, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "figure.dpi": 300, + "text.usetex": False, + } +) + + +def fitted_gather( + min_val, + max_val, + feature_dim, + hbm_bandwidth_bytes_per_sec, + l2_bandwidth_bytes_per_sec, + launch_overhead_seconds, + L2_thresh, + HBM_thresh, +): + """Return a function that models gather time as a function of k.""" + + def time_model(b, overhead, inv_bw_L2, inv_bw_HBM, L2_thresh, HBM_thresh): + # 1. Bucket the bytes into their respective physical regimes + # Bytes processed exclusively at L2 speeds + bytes_L2 = np.clip(b, 0, L2_thresh) + # Bytes processed exclusively at HBM speeds + bytes_HBM = np.maximum(0, b - HBM_thresh) + + # 2. Apply the specific bandwidth (slope) to each bucket + t_mem = (bytes_L2 * inv_bw_L2) + (bytes_HBM * inv_bw_HBM) + + # 3. 
Floor the total time by the kernel launch overhead + return np.maximum(overhead, t_mem) + + x = np.linspace(min_val, max_val, 100) + inv_bw_L2 = 1.0 / l2_bandwidth_bytes_per_sec + inv_bw_HBM = 1.0 / hbm_bandwidth_bytes_per_sec + y = ( + time_model( + x * feature_dim * 4.0, + launch_overhead_seconds, + inv_bw_L2, + inv_bw_HBM, + L2_thresh, + HBM_thresh, + ) + * 1e3 + ) + return x, y + + +def load_gather_file(paths: list, timing_key: str): + """Merge multiple JSON files, return sorted (k, median, q25, q75) arrays.""" + rows = [] + for p in paths: + with open(p) as f: + data = json.load(f) + for meas in data["measurements"]: + trials = np.array(meas[timing_key]) + rows.append( + ( + meas["params"]["k"], + float(np.median(trials)), + float(np.percentile(trials, 25)), + float(np.percentile(trials, 75)), + ) + ) + rows.sort(key=lambda r: r[0]) + k_arr = np.array([r[0] for r in rows]) + med_arr = np.array([r[1] for r in rows]) + q25_arr = np.array([r[2] for r in rows]) + q75_arr = np.array([r[3] for r in rows]) + return k_arr, med_arr, q25_arr, q75_arr + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot gather/scatter-add bandwidth") + p.add_argument("--contiguous", nargs="+", default=[], metavar="FILE") + p.add_argument("--clustered", nargs="+", default=[], metavar="FILE") + p.add_argument("--random", nargs="+", default=[], metavar="FILE") + p.add_argument("--operation", choices=["gather", "scatter_add"], default="gather") + p.add_argument("--fitted", type=str, default=None, metavar="FILE") + p.add_argument("--output", type=str, default="figures/gather") + return p.parse_args() + + +def main(): + args = parse_args() + timing_key = ( + "gather_trials_seconds" + if args.operation == "gather" + else "scatter_add_trials_seconds" + ) + + fig, ax = plt.subplots(figsize=(5, 3.5)) + + if args.fitted: + with open(args.fitted) as f: + primitives = json.load(f) + min_k, max_k = 1e3, 1e9 + for dist_name, files in [ + ("contiguous", args.contiguous), + 
("clustered", args.clustered), + ("random", args.random), + ]: + if not files: + continue + k, med, q25, q75 = load_gather_file(files, timing_key) + min_k = k[0] + max_k = k[-1] + color = COLORS[dist_name] + ax.errorbar( + k * 1e-6, + med * 1e3, + yerr=[(med - q25) * 1e3, (q75 - med) * 1e3], + fmt="o", + markersize=3, + color=color, + linewidth=0.9, + capsize=2, + elinewidth=0.8, + label=dist_name.capitalize(), + ) + + if args.fitted: + hbm_bw = primitives["gather"][dist_name]["gather"][ + "bandwidth_bytes_per_sec" + ] + l2_bw = primitives["gather"][dist_name]["gather"][ + "L2_bandwidth_bytes_per_sec" + ] + overhead = primitives["gather"][dist_name]["gather"][ + "launch_overhead_seconds" + ] + thresh = primitives["gather"][dist_name]["gather"]["L2_inflection_bytes"] + hbm_thresh = primitives["gather"][dist_name]["gather"][ + "HBM_inflection_bytes" + ] + x, y = fitted_gather( + min_k, + max_k, + feature_dim=512, + hbm_bandwidth_bytes_per_sec=hbm_bw, + l2_bandwidth_bytes_per_sec=l2_bw, + launch_overhead_seconds=overhead, + L2_thresh=thresh, + HBM_thresh=hbm_thresh, + ) + ax.plot( + x * 1e-6, + y, + "--", + color=color, + linewidth=1.2, + label=f"{dist_name.capitalize()}-Expected", + alpha=0.4, + ) + ax.set_xscale("log") + ax.set_yscale("log") + op_label = ( + "Gather $x[\\mathrm{idx}]$" + if args.operation == "gather" + else "Scatter-add (backward)" + ) + ax.set_xlabel("k (millions of rows gathered)") + ax.set_ylabel(f"{op_label} time (ms)") + ax.set_title(f"Buffer-Copy Bandwidth: {op_label}", fontsize=9) + ax.legend() + ax.grid(True, which="both", linestyle=":", linewidth=0.4) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_gather] Saved {out}.pdf and {out}.png") + print( + f"Caption: {op_label} time vs. 
gather size $k$ for three index " + "distributions: contiguous (best case, cache-friendly), clustered " + "(METIS-partitioned halo pattern), and random (worst case). " + "Error bars span IQR. The gap between contiguous/clustered and random " + "quantifies the cache-miss penalty relevant to poorly-partitioned graphs." + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_pingpong.py b/experiments/cost_model_benchmarks/visualization/plot_pingpong.py new file mode 100644 index 0000000..b7625a4 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_pingpong.py @@ -0,0 +1,147 @@ +"""Visualization — Ping-Pong Bandwidth / Latency. + +Produces a log-log plot of one-way transfer time vs. message size for intra- +and inter-node measurements, with fitted lines overlaid and a residuals inset. + +Usage:: + + python -m visualization.plot_pingpong \\ + --intra data/pingpong_intra_*.json \\ + --inter data/pingpong_inter_*.json \\ + --primitives data/fitted_primitives.json \\ + --output figures/pingpong +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.ticker import LogLocator + +# --------------------------------------------------------------------------- +# Shared style +# --------------------------------------------------------------------------- +COLORS = {"intra": "#1f77b4", "inter": "#d62728"} +plt.rcParams.update({ + "font.size": 9, + "axes.labelsize": 9, + "legend.fontsize": 8, + "xtick.labelsize": 8, + "ytick.labelsize": 8, + "figure.dpi": 300, + "text.usetex": False, +}) + + +def load_measurements(files: list) -> tuple: + """Return (byte_sizes, medians, q25, q75) from a list of JSON files.""" + byte_sizes, medians, q25s, q75s = [], [], [], [] + for p in files: + with open(p) as f: + data = json.load(f) + for meas in data["measurements"]: + trials = 
np.array(meas["trials_seconds"]) + byte_sizes.append(meas["params"]["message_bytes"]) + medians.append(float(np.median(trials))) + q25s.append(float(np.percentile(trials, 25))) + q75s.append(float(np.percentile(trials, 75))) + order = np.argsort(byte_sizes) + return ( + np.array(byte_sizes)[order], + np.array(medians)[order], + np.array(q25s)[order], + np.array(q75s)[order], + ) + + +def fitted_line(bytes_arr: np.ndarray, primitives: dict, mode: str) -> np.ndarray: + net = primitives.get("network", {}).get(mode, None) + if net is None: + return None + t_L = net["latency_seconds"] + B = net["bandwidth_bytes_per_sec"] + return t_L + bytes_arr / B + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot ping-pong results") + p.add_argument("--intra", nargs="+", default=[], metavar="FILE") + p.add_argument("--inter", nargs="+", default=[], metavar="FILE") + p.add_argument("--primitives", type=str, default=None) + p.add_argument("--output", type=str, default="figures/pingpong") + return p.parse_args() + + +def main(): + args = parse_args() + + primitives = {} + if args.primitives: + with open(args.primitives) as f: + primitives = json.load(f) + + fig, axes = plt.subplots(1, 2, figsize=(7, 3.2)) + ax_main, ax_res = axes + + for mode, files in [("intra", args.intra), ("inter", args.inter)]: + if not files: + continue + xb, med, q25, q75 = load_measurements(files) + color = COLORS[mode] + label = "NVLink (intra)" if mode == "intra" else "InfiniBand (inter)" + + yerr_lo = med - q25 + yerr_hi = q75 - med + ax_main.errorbar( + xb * 1e-6, med * 1e3, + yerr=[yerr_lo * 1e3, yerr_hi * 1e3], + fmt="o", markersize=3, color=color, label=label, + capsize=2, linewidth=0.8, elinewidth=0.8, + ) + + fit = fitted_line(xb, primitives, mode) + if fit is not None: + ax_main.plot(xb * 1e-6, fit * 1e3, "--", color=color, + linewidth=1.2, label=f"{label} fit") + # Residuals + residuals = (med - fit) / med * 100 # percent + ax_res.plot(xb * 1e-6, residuals, "o-", markersize=3, 
color=color, + linewidth=0.8, label=label) + + ax_main.set_xscale("log") + ax_main.set_yscale("log") + ax_main.set_xlabel("Message size (MB)") + ax_main.set_ylabel("One-way transfer time (ms)") + ax_main.legend(loc="upper left") + ax_main.grid(True, which="both", linestyle=":", linewidth=0.4) + + ax_res.axhline(0, color="k", linewidth=0.6, linestyle="--") + ax_res.set_xscale("log") + ax_res.set_xlabel("Message size (MB)") + ax_res.set_ylabel("Residual (%)") + ax_res.legend() + ax_res.grid(True, which="both", linestyle=":", linewidth=0.4) + + fig.suptitle("Network Ping-Pong: Transfer Time vs. Message Size", fontsize=10) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_pingpong] Saved {out}.pdf and {out}.png") + print( + "Caption: Log-log plot of one-way network transfer time vs. message size " + "for intra-node (NVLink) and inter-node (InfiniBand) communication. " + "Points show medians; error bars span the 25th--75th percentile (IQR). " + "Dashed lines show linear-latency fits $T = t_L + s/B$. " + "Right panel shows residuals (\\%) from the fit." + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_tipping_point.py b/experiments/cost_model_benchmarks/visualization/plot_tipping_point.py new file mode 100644 index 0000000..735c96d --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_tipping_point.py @@ -0,0 +1,141 @@ +"""Visualization — T_global vs. K (Tipping Point). + +Shows total training throughput or per-layer time as a function of the number +of GPUs K for a fixed graph. Overlays the cost-model prediction (solid) and +measured values (dashed with markers). Annotates K* — the point beyond which +adding more GPUs yields diminishing returns. 
+ +T_global(K) is computed as: + + T_global(K) = T_layer(K) * num_layers * num_epochs + +For the tipping-point annotation K* is the largest K where the speedup +relative to K=1 is still within 10% of linear. + +Usage:: + + python -m visualization.plot_tipping_point \\ + --predictions data/predictions.json \\ + --num-layers 3 \\ + --num-epochs 100 \\ + --graph erdos_renyi \\ + --output figures/tipping_point +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8, + "xtick.labelsize": 8, "ytick.labelsize": 8, + "figure.dpi": 300, "text.usetex": False, +}) + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot T_global vs. K tipping point") + p.add_argument("--predictions", type=str, required=True) + p.add_argument("--num-layers", type=int, default=3) + p.add_argument("--num-epochs", type=int, default=100) + p.add_argument("--graph", type=str, default=None, + help="Filter to this graph type (optional)") + p.add_argument("--output", type=str, default="figures/tipping_point") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.predictions) as f: + data = json.load(f) + + entries = data["predictions"] + + # Filter by graph type if requested + if args.graph: + entries = [e for e in entries + if e["config"].get("graph", "") == args.graph] + + # Group by world_size + ws_to_meas = {} + ws_to_pred = {} + for e in entries: + K = e["config"].get("world_size", 1) + ws_to_meas.setdefault(K, []).append(e["measured_median_seconds"]) + ws_to_pred.setdefault(K, []).append(e["predicted_seconds"]) + + if not ws_to_meas: + print("[plot_tipping_point] No data found. 
Exiting.") + return + + K_vals = sorted(ws_to_meas.keys()) + T_meas = np.array([np.median(ws_to_meas[K]) for K in K_vals]) * args.num_layers * args.num_epochs + T_pred = np.array([np.median(ws_to_pred[K]) for K in K_vals]) * args.num_layers * args.num_epochs + K_arr = np.array(K_vals, dtype=float) + + # Ideal linear scaling from K=1 + T_single = T_meas[0] # K=1 reference + T_ideal = T_single / K_arr + + # Identify K*: largest K where speedup is ≥ 90% of ideal + speedup = T_single / T_meas + ideal_speedup = K_arr + efficiency = speedup / ideal_speedup + kstar_idx = np.where(efficiency >= 0.9)[0] + K_star = K_arr[kstar_idx[-1]] if len(kstar_idx) else K_arr[0] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3.5)) + + # --- Left panel: T_global vs K --- + ax1.plot(K_arr, T_ideal, "k:", linewidth=1, label="Ideal linear scaling") + ax1.plot(K_arr, T_pred, "-", color="#1f77b4", linewidth=1.5, label="Predicted") + ax1.plot(K_arr, T_meas, "o--", color="#d62728", markersize=5, linewidth=1.2, + label="Measured") + ax1.axvline(K_star, color="gray", linestyle="--", linewidth=0.8) + ax1.text(K_star * 1.05, ax1.get_ylim()[1] * 0.95, + f"$K^* = {int(K_star)}$", fontsize=8, color="gray", va="top") + ax1.set_xlabel("Number of GPUs ($K$)") + ax1.set_ylabel(f"$T_{{\\mathrm{{global}}}}$ (s)\n" + f"({args.num_layers} layers × {args.num_epochs} epochs)") + ax1.set_title("Training Time vs. 
GPU Count", fontsize=9) + ax1.legend() + ax1.grid(True, linestyle=":", linewidth=0.4) + + # --- Right panel: scaling efficiency --- + ax2.plot(K_arr, efficiency * 100, "o-", color="#2ca02c", markersize=5, linewidth=1.2, + label="Scaling efficiency") + ax2.axhline(90, color="gray", linestyle="--", linewidth=0.8, label="90% threshold") + ax2.axvline(K_star, color="gray", linestyle="--", linewidth=0.8) + ax2.set_xlabel("Number of GPUs ($K$)") + ax2.set_ylabel("Scaling efficiency (%)") + ax2.set_title("Strong Scaling Efficiency", fontsize=9) + ax2.set_ylim(0, 110) + ax2.legend() + ax2.grid(True, linestyle=":", linewidth=0.4) + + fig.suptitle("Tipping-Point Analysis: When Does Adding GPUs Stop Helping?", fontsize=10) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_tipping_point] Saved {out}.pdf and {out}.png K*={int(K_star)}") + print( + f"Caption: (Left) Total training time $T_{{\\mathrm{{global}}}}$ vs. GPU count $K$ " + f"for {args.num_layers}-layer GNN trained for {args.num_epochs} epochs. " + "Solid blue: cost-model prediction. Red dashed: measured. " + "Dotted black: ideal linear speedup. " + f"$K^* = {int(K_star)}$ marks the last GPU count with $\\geq 90\\%$ scaling efficiency. " + "(Right) Scaling efficiency $= \\text{{speedup}} / K$." + ) + + +if __name__ == "__main__": + main() diff --git a/experiments/cost_model_benchmarks/visualization/plot_validation.py b/experiments/cost_model_benchmarks/visualization/plot_validation.py new file mode 100644 index 0000000..65a6497 --- /dev/null +++ b/experiments/cost_model_benchmarks/visualization/plot_validation.py @@ -0,0 +1,127 @@ +"""Visualization — Predicted vs. Measured Scatter (Headline Figure). + +Reads ``data/predictions.json`` and produces a predicted-vs-measured scatter +plot. Points are colored by graph type (or fit/held-out split). 
+ +Usage:: + + python -m visualization.plot_validation \\ + --predictions data/predictions.json \\ + --color-by split \\ + --output figures/validation +""" + +import argparse +import json +from pathlib import Path + +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +plt.rcParams.update({ + "font.size": 9, "axes.labelsize": 9, "legend.fontsize": 8, + "xtick.labelsize": 8, "ytick.labelsize": 8, + "figure.dpi": 300, "text.usetex": False, +}) + + +def parse_args(): + p = argparse.ArgumentParser(description="Plot predicted vs. measured scatter") + p.add_argument("--predictions", type=str, required=True) + p.add_argument("--color-by", choices=["split", "graph", "world_size"], + default="split") + p.add_argument("--output", type=str, default="figures/validation") + return p.parse_args() + + +def main(): + args = parse_args() + + with open(args.predictions) as f: + data = json.load(f) + + entries = data["predictions"] + mape_fit = data["aggregate"]["mape_fit_set"] + mape_held = data["aggregate"]["mape_held_out"] + + meas = np.array([e["measured_median_seconds"] * 1e3 for e in entries]) + pred = np.array([e["predicted_seconds"] * 1e3 for e in entries]) + + # Color groups + if args.color_by == "split": + groups = { + "Fit set": [i for i, e in enumerate(entries) if e["in_fit_set"]], + "Held-out": [i for i, e in enumerate(entries) if not e["in_fit_set"]], + } + palette = {"Fit set": "#1f77b4", "Held-out": "#d62728"} + elif args.color_by == "graph": + graph_types = sorted(set(e["config"].get("graph", "unknown") for e in entries)) + palette = dict(zip(graph_types, ["#1f77b4", "#ff7f0e", "#2ca02c", "#9467bd"])) + groups = {g: [i for i, e in enumerate(entries) + if e["config"].get("graph", "unknown") == g] + for g in graph_types} + else: # world_size + ws_vals = sorted(set(e["config"].get("world_size", 1) for e in entries)) + colors = plt.cm.viridis(np.linspace(0, 1, len(ws_vals))) + palette = {f"K={w}": c for w, c in zip(ws_vals, 
colors)} + groups = {f"K={w}": [i for i, e in enumerate(entries) + if e["config"].get("world_size", 1) == w] + for w in ws_vals} + + fig, ax = plt.subplots(figsize=(4.5, 4.5)) + + all_vals = np.concatenate([meas, pred]) + lo, hi = all_vals.min() * 0.9, all_vals.max() * 1.1 + ax.plot([lo, hi], [lo, hi], "k--", linewidth=0.8, label="Ideal (y=x)") + + # 10% error bands + ax.fill_between([lo, hi], [lo * 0.9, hi * 0.9], [lo * 1.1, hi * 1.1], + alpha=0.08, color="gray") + + for label, idxs in groups.items(): + if not idxs: + continue + color = palette[label] + ax.scatter(meas[idxs], pred[idxs], s=22, color=color, + alpha=0.85, label=label, zorder=3) + + ax.set_xlim(lo, hi) + ax.set_ylim(lo, hi) + ax.set_xlabel("Measured $T_{\\mathrm{layer}}$ (ms)") + ax.set_ylabel("Predicted $T_{\\mathrm{layer}}$ (ms)") + ax.set_title("Cost Model Validation: Predicted vs. Measured", fontsize=9) + + # Annotate MAPE + mape_text = ( + f"Fit MAPE = {mape_fit*100:.1f}%\n" + f"Held-out MAPE = {mape_held*100:.1f}%" + ) + ax.text(0.04, 0.96, mape_text, transform=ax.transAxes, + verticalalignment="top", fontsize=8, + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)) + + ax.legend(loc="lower right", fontsize=8) + ax.grid(True, linestyle=":", linewidth=0.4) + fig.tight_layout() + + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(str(out) + ".pdf", bbox_inches="tight") + fig.savefig(str(out) + ".png", bbox_inches="tight", dpi=300) + print(f"[plot_validation] Saved {out}.pdf and {out}.png") + print( + "Caption: Predicted vs. measured layer time $T_{\\mathrm{layer}}$ for all " + "benchmarked configurations. Dashed diagonal: perfect prediction. " + "Shaded band: $\\pm10\\%$ error region. " + f"In-sample MAPE = {mape_fit*100:.1f}\\%, " + f"held-out MAPE = {mape_held*100:.1f}\\%. " + "Colors distinguish " + + ("fit-set vs. held-out configurations." 
if args.color_by == "split" + else f"configurations by {args.color_by}.") + ) + + +if __name__ == "__main__": + main()