Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions DGraph/distributed/commInfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,22 @@ def compute_halo_vertices(
"""
Computes halo vertices. Supports both homogeneous and bipartite/heterogeneous relations.
"""
# Fallback for homogeneous graphs
if dst_partitioning is None:
dst_partitioning = src_partitioning

src_rank = src_partitioning[edge_list[:, 0]]
dst_rank = dst_partitioning[edge_list[:, 1]]

# Cross-rank mask: source is local, destination is remote
cross_mask = (src_rank == rank) & (dst_rank != rank)
# FIX: Cross-rank mask for pull model: source is remote, destination is local
cross_mask = (src_rank != rank) & (dst_rank == rank)

# Return unique destination vertex IDs from those edges
return torch.unique(edge_list[cross_mask, 1])
# FIX: Return unique SOURCE vertex IDs from those edges
return torch.unique(edge_list[cross_mask, 0])


def compute_local_edge_list(
global_edge_list: torch.Tensor, # [E, 2]
partitioning: torch.Tensor, # [V]
partitioning: torch.Tensor, # [V] (Acts as dst_partitioning)
local_vertices_global: torch.Tensor, # [num_local]
halo_vertices_global: torch.Tensor, # [num_halo]
rank: int,
Expand All @@ -72,8 +71,8 @@ def compute_local_edge_list(
num_halo = halo_vertices_global.size(0)
num_global = partitioning.size(0)

# Filter edges owned by this rank
local_edge_mask = partitioning[global_edge_list[:, 0]] == rank
# FIX: Filter edges where the DESTINATION is owned by this rank (Index 1)
local_edge_mask = partitioning[global_edge_list[:, 1]] == rank
local_edges_global = global_edge_list[local_edge_mask]

# Build inverse map: global_id -> local_idx via scatter
Expand Down
115 changes: 75 additions & 40 deletions experiments/GraphCast/data_utils/graphcast_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,29 @@ class GraphCastTopology:

@dataclass
class DistributedGraphCastGraph:
# Distributed environment info
rank: int
world_size: int
ranks_per_graph: int

# Graph metadata
mesh_level: int
lat_lon_grid: Tensor

# Mesh vertex features
mesh_graph_node_features: Tensor
mesh_graph_edge_features: Tensor
mesh_graph_node_rank_placement: Tensor
mesh_graph_edge_rank_placement: Tensor
mesh_graph_src_indices: Tensor
mesh_graph_dst_indices: Tensor
mesh_graph_src_rank_placement: Tensor
mesh_graph_dst_rank_placement: Tensor
grid_rank_placement: Tensor

# Grid vertex features
mesh2grid_graph_node_features: Tensor
mesh2grid_graph_edge_features: Tensor
mesh2grid_graph_edge_rank_placement: Tensor
mesh2grid_graph_src_indices: Tensor
mesh2grid_graph_dst_indices: Tensor
grid2mesh_graph_node_features: Tensor

# Mesh <--> Grid edge features
mesh2grid_graph_edge_features: Tensor
grid2mesh_graph_edge_features: Tensor
grid2mesh_graph_src_indices: Tensor
grid2mesh_graph_dst_indices: Tensor

# Distributed graph info
distributed_comm_patterns: GraphCastCommPatterns


def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatterns:
Expand Down Expand Up @@ -110,7 +110,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt
grid2mesh_cp = build_communication_pattern(
global_edge_list=grid2mesh_edges,
partitioning=mesh_part,
neighbor_partitioning=grid_part,
rank=rank,
world_size=world_size,
)
Expand All @@ -129,7 +128,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt
mesh_cp = build_communication_pattern(
global_edge_list=mesh_edges,
partitioning=mesh_part,
neighbor_partitioning=mesh_part,
rank=rank,
world_size=world_size,
)
Expand All @@ -147,7 +145,6 @@ def build_graphcast_comm_patterns(graph: GraphCastTopology) -> GraphCastCommPatt
mesh2grid_cp = build_communication_pattern(
global_edge_list=mesh2grid_edges,
partitioning=grid_part,
neighbor_partitioning=mesh_part,
rank=rank,
world_size=world_size,
)
Expand Down Expand Up @@ -220,6 +217,47 @@ def get_mesh_graph_partition(mesh_level: int, world_size: int):
mesh_vertex_rank_placement = torch.tensor(mesh_vertex_rank_placement)
return mesh_vertex_rank_placement

@staticmethod
def get_grid_vertex_partition(
    lat: int,
    lon: int,
    mesh_vertex_rank_placement: torch.Tensor,
    grid2mesh_grid_src_indices: torch.Tensor,
    grid2mesh_mesh_dst_indices: torch.Tensor,
    mesh2grid_mesh_src_indices: torch.Tensor,
    world_size: int,
) -> torch.Tensor:
    """Assign each grid vertex to the rank holding most of its mesh neighbours.

    Every grid2mesh edge (grid vertex -> mesh vertex) and every mesh2grid
    edge (mesh vertex -> grid vertex) casts one vote for the rank that owns
    the mesh endpoint. Each grid vertex is then placed on the rank with the
    highest vote count, which keeps as many of its edges rank-local as
    possible.

    The grid endpoint of a mesh2grid edge is implicit: edge ``k`` is taken
    to end at grid vertex ``k // 3``, i.e. grid vertex ``i`` owns edges
    ``[3i, 3i+1, 3i+2]``. Per the original author's note this relies on
    ``create_mesh2grid_graph`` emitting exactly three edges (one mesh face)
    per grid vertex, in grid-vertex order.

    Args:
        lat: Number of latitude points; ``lat * lon`` is the grid size.
        lon: Number of longitude points.
        mesh_vertex_rank_placement: Owning rank of every mesh vertex.
        grid2mesh_grid_src_indices: Grid (source) vertex of each grid2mesh
            edge.
        grid2mesh_mesh_dst_indices: Mesh (destination) vertex of each
            grid2mesh edge; must have the same shape as the source indices.
        mesh2grid_mesh_src_indices: Mesh (source) vertex of each mesh2grid
            edge; must contain exactly ``3 * lat * lon`` entries.
        world_size: Number of ranks, i.e. number of vote columns.

    Returns:
        A ``[lat * lon]`` int64 tensor of rank ids, on the same device as
        ``mesh_vertex_rank_placement``. A grid vertex with no incident
        edges has an all-zero vote row.

    Raises:
        ValueError: If the grid2mesh index tensors differ in shape, or the
            mesh2grid tensor does not hold three edges per grid vertex.
    """
    num_grid = lat * lon
    # Work on the placement tensor's device so GPU inputs are supported
    # instead of colliding with CPU-allocated scratch tensors.
    device = mesh_vertex_rank_placement.device

    g2m_src = grid2mesh_grid_src_indices.to(device=device, dtype=torch.long)
    g2m_dst = grid2mesh_mesh_dst_indices.to(device=device, dtype=torch.long)
    m2g_src = mesh2grid_mesh_src_indices.to(device=device, dtype=torch.long)

    if g2m_src.shape != g2m_dst.shape:
        raise ValueError(
            "grid2mesh src/dst index tensors must have the same shape, got "
            f"{tuple(g2m_src.shape)} and {tuple(g2m_dst.shape)}"
        )
    if m2g_src.numel() != 3 * num_grid:
        # Without this check a wrong-length tensor either raises an opaque
        # broadcasting error or (length 1) silently broadcasts.
        raise ValueError(
            "expected exactly 3 mesh2grid edges per grid vertex "
            f"({3 * num_grid} total), got {m2g_src.numel()}"
        )

    mesh_ranks = mesh_vertex_rank_placement.long()

    # Implicit grid destination of mesh2grid edge k is k // 3.
    m2g_grid_dst = torch.arange(
        num_grid, dtype=torch.long, device=device
    ).repeat_interleave(3)

    # Flatten each (grid_vertex, rank) vote into one index of a 1-D tally so
    # both edge sets can be accumulated with a single scatter_add_.
    flat_idx = torch.cat(
        [
            g2m_src.reshape(-1) * world_size + mesh_ranks[g2m_dst.reshape(-1)],
            m2g_grid_dst * world_size + mesh_ranks[m2g_src.reshape(-1)],
        ]
    )
    votes = torch.zeros(num_grid * world_size, dtype=torch.long, device=device)
    votes.scatter_add_(0, flat_idx, torch.ones_like(flat_idx))

    # Pick, per grid vertex, the rank column with the most votes.
    return votes.view(num_grid, world_size).argmax(dim=1)

def get_mesh_graph(self, mesh_vertex_rank_placement: torch.Tensor):
"""Get the graph for the distributed graphcast graph."""

Expand Down Expand Up @@ -327,8 +365,6 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict):
grid_vertex_rank_placement
)

# TODO: Consider we can have it to so grid2mesh edges don't require a
# backpropagation. (If encoder is after the gather / scatter)
grid2mesh_graph_dict = {
"node_features": torch.tensor([]),
"edge_features": edge_features,
Expand All @@ -340,7 +376,7 @@ def get_grid2mesh_graph(self, mesh_graph_dict: dict):
}
return grid2mesh_graph_dict

def get_mesh2grid_graph(
def get_mesh2grid_edges(
self,
grid_vertex_rank_placement,
renumbered_vertices,
Expand All @@ -359,7 +395,6 @@ def get_mesh2grid_graph(
mesh2grid_edge_rank_placement = grid_vertex_rank_placement[dst_grid_indices]

mesh2grid_graph_dict = {
"node_features": torch.tensor([]),
"edge_features": edge_features,
"src_indices": src_mesh_indices,
"dst_indices": dst_grid_indices,
Expand Down Expand Up @@ -401,10 +436,26 @@ def get_graphcast_graph(
grid_vertex_rank_placement = grid2mesh_graph["grid_vertex_rank_placement"]
renumbered_grid = grid2mesh_graph["renumbered_grid"]

mesh2grid_graph = self.get_mesh2grid_graph(
mesh2grid_graph = self.get_mesh2grid_edges(
grid_vertex_rank_placement, renumbered_vertices, renumbered_grid
)

topology = GraphCastTopology(
rank=self.local_rank,
world_size=self.world_size,
ranks_per_graph=self.ranks_per_graph,
mesh_rank_placement=mesh_vertex_rank_placement,
grid_rank_placement=grid_vertex_rank_placement,
mesh_graph_src_indices=mesh_graph["src_indices"],
mesh_graph_dst_indices=mesh_graph["dst_indices"],
mesh2grid_graph_src_indices=mesh2grid_graph["src_indices"],
mesh2grid_graph_dst_indices=mesh2grid_graph["dst_indices"],
grid2mesh_graph_src_indices=grid2mesh_graph["src_indices"],
grid2mesh_graph_dst_indices=grid2mesh_graph["dst_indices"],
)

comm_patterns = build_graphcast_comm_patterns(topology)

return DistributedGraphCastGraph(
rank=self.rank,
world_size=self.world_size,
Expand All @@ -413,25 +464,9 @@ def get_graphcast_graph(
lat_lon_grid=self.lat_lon_grid,
mesh_graph_node_features=mesh_graph["node_features"],
mesh_graph_edge_features=mesh_graph["edge_features"],
mesh_graph_node_rank_placement=mesh_graph["node_rank_placement"],
mesh_graph_edge_rank_placement=mesh_graph["edge_rank_placement"],
mesh_graph_src_indices=mesh_graph["src_indices"],
mesh_graph_dst_indices=mesh_graph["dst_indices"],
mesh_graph_src_rank_placement=mesh_graph["src_rank_placement"],
mesh_graph_dst_rank_placement=mesh_graph["dst_rank_placement"],
grid_rank_placement=grid2mesh_graph["grid_vertex_rank_placement"],
mesh2grid_graph_node_features=mesh2grid_graph["node_features"],
mesh2grid_graph_edge_features=mesh2grid_graph["edge_features"],
mesh2grid_graph_edge_rank_placement=mesh2grid_graph[
"mesh2grid_edge_rank_placement"
],
mesh2grid_graph_src_indices=mesh2grid_graph["src_indices"],
mesh2grid_graph_dst_indices=mesh2grid_graph["dst_indices"],
mesh2grid_graph_node_features=torch.tensor([]),
grid2mesh_graph_node_features=grid2mesh_graph["node_features"],
mesh2grid_graph_edge_features=mesh2grid_graph["edge_features"],
grid2mesh_graph_edge_features=grid2mesh_graph["edge_features"],
grid2mesh_graph_edge_rank_placement=grid2mesh_graph[
"grid2mesh_edge_rank_placement"
],
grid2mesh_graph_src_indices=grid2mesh_graph["src_indices"],
grid2mesh_graph_dst_indices=grid2mesh_graph["dst_indices"],
distributed_comm_patterns=comm_patterns,
)
6 changes: 0 additions & 6 deletions experiments/GraphCast/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,20 +261,14 @@ def test_synthetic_weather_dataset(num_days, batch_size=1):
print("Mesh label:\t", static_graph.mesh_level)
print("Mesh Node features:\t", static_graph.mesh_graph_node_features.shape)
print("Mesh Edge features:\t", static_graph.mesh_graph_edge_features.shape)
print("Mesh src indices:\t", static_graph.mesh_graph_src_indices.shape)
print("Mesh dst indices:\t", static_graph.mesh_graph_dst_indices.shape)
print("=" * 80)
print(
"mesh2grid edge features:\t", static_graph.mesh2grid_graph_edge_features.shape
)
print("mesh2grid src indices:\t", static_graph.mesh2grid_graph_src_indices.shape)
print("mesh2grid dst indices:\t", static_graph.mesh2grid_graph_dst_indices.shape)
print("=" * 80)
print(
"grid2mesh edge features:\t", static_graph.grid2mesh_graph_edge_features.shape
)
print("grid2mesh src indices:\t", static_graph.grid2mesh_graph_src_indices.shape)
print("grid2mesh dst indices:\t", static_graph.grid2mesh_graph_dst_indices.shape)
print("=" * 80)


Expand Down
Loading