1use std::collections::BTreeSet;
4
5use itertools::Itertools;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, HandoffKind};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::graph::graph_algorithms::SubgraphMerge;
13
14fn find_edge_barriers(
21 partitioned_graph: &DfirGraph,
22) -> (
23 SecondaryMap<GraphEdgeId, DelayType>,
24 Vec<(GraphNodeId, GraphNodeId)>,
25) {
26 let mut tick_edges = SecondaryMap::new();
27 let mut barrier_pairs = Vec::new();
28
29 for (edge_id, (src, dst)) in partitioned_graph.edges() {
30 if partitioned_graph.node_loop(dst).is_some() {
32 continue;
33 }
34 let Some(op_inst) = partitioned_graph.node_op_inst(dst) else {
35 continue;
36 };
37 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
38 let Some(delay_type) = (op_inst.op_constraints.input_delaytype_fn)(dst_port) else {
39 continue;
40 };
41
42 barrier_pairs.push((src, dst));
43 if matches!(delay_type, DelayType::Tick | DelayType::TickLazy) {
44 tick_edges.insert(edge_id, delay_type);
45 }
46 }
47
48 (tick_edges, barrier_pairs)
49}
50
51fn find_access_group_ordering(partitioned_graph: &DfirGraph) -> Vec<(GraphNodeId, GraphNodeId)> {
54 let mut pairs = Vec::new();
55 let refs_by_target = partitioned_graph.node_handoff_reference_groups();
56 for (_handoff, groups) in refs_by_target {
57 for (group_a, group_b) in groups.values().tuple_windows() {
58 for &(node_a, _, _) in group_a {
59 for &(node_b, _, _) in group_b {
60 assert_ne!(
62 node_a, node_b,
63 "encounted conflicted or cyclical handoff references\n{:?}\n{:?}",
64 group_a, group_b,
65 );
66 pairs.push((node_a, node_b));
67 }
68 }
69 }
70 }
71 pairs
72}
73
/// Groups the nodes of `partitioned_graph` into subgraphs using a [`SubgraphMerge`] union-find.
///
/// Ordering constraints ("predecessors") are gathered from:
/// * every graph edge that is not in `tick_edges` (`src` precedes `dst`);
/// * handoff references: a referenced node precedes the referencing node. When the referenced
///   node is a handoff, the referencing node also precedes each successor of that handoff;
/// * `access_group_pairs`.
///
/// Pairs that must not be merged ("enemies") are the `edge_barrier_pairs`, the
/// `access_group_pairs`, and every `(referenced, referencing)` handoff-reference pair.
///
/// Edges are then merged greedily, with full passes repeated until a pass makes no progress.
/// An edge `src -> dst` is merged only if all of the following hold:
/// * neither endpoint is a handoff;
/// * the endpoints are not already in the same set;
/// * both endpoints report the same `node_loop`;
/// * `dst` is not a `FloType::NextIteration` operator;
/// * `can_connect_colorize` accepts the endpoint colors;
/// * `SubgraphMerge::try_merge` succeeds.
///
/// Returns the union-find together with the edges that were neither merged nor adjacent to a
/// handoff node. `make_subgraphs` inserts handoffs on exactly these edges.
///
/// # Errors
///
/// `SubgraphMerge::new` may report a cycle among the predecessor constraints. In that case an
/// error [`Diagnostic`] is returned, spanned at the first node of the cycle (or the call site if
/// the cycle is empty) and listing the cycle's nodes.
fn find_subgraph_unionfind(
    partitioned_graph: &DfirGraph,
    tick_edges: &SecondaryMap<GraphEdgeId, DelayType>,
    edge_barrier_pairs: &[(GraphNodeId, GraphNodeId)],
    access_group_pairs: &[(GraphNodeId, GraphNodeId)],
) -> Result<(SubgraphMerge<GraphNodeId>, BTreeSet<GraphEdgeId>), Diagnostic> {
    // Initial colors. Nodes for which `node_color` returns `None` start uncolored and may be
    // assigned a color by `can_connect_colorize` during merging.
    let mut node_color = partitioned_graph
        .node_ids()
        .filter_map(|node_id| {
            let op_color = partitioned_graph.node_color(node_id)?;
            Some((node_id, op_color))
        })
        .collect::<SparseSecondaryMap<_, _>>();

    // `all_preds[n]` lists the nodes that must come before `n`.
    let mut all_preds: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = SecondaryMap::new();

    // Ordinary edges order src before dst. Tick edges are left out; presumably they cross a
    // tick boundary and impose no ordering within a tick (the cycle error below suggests
    // `defer_tick()` to break cycles across ticks).
    for (edge_id, (src, dst)) in partitioned_graph.edges() {
        if !tick_edges.contains_key(edge_id) {
            all_preds.entry(dst).unwrap().or_default().push(src);
        }
    }

    // Handoff references: the referenced node precedes the referencing node. If the referenced
    // node is a handoff, the referencing node must also precede that handoff's consumers.
    for node_id in partitioned_graph.node_ids() {
        for handoff_ref in partitioned_graph.node_handoff_references(node_id).iter() {
            if let Some(src) = handoff_ref.node_id {
                all_preds.entry(node_id).unwrap().or_default().push(src);
                if let GraphNode::Handoff { .. } = partitioned_graph.node(src) {
                    for (_edge, consumer) in partitioned_graph.node_successors(src) {
                        all_preds
                            .entry(consumer)
                            .unwrap()
                            .or_default()
                            .push(node_id);
                    }
                }
            }
        }
    }

    // Access-group ordering: first element precedes second.
    for &(src, dst) in access_group_pairs {
        all_preds.entry(dst).unwrap().or_default().push(src);
    }

    // Pairs of nodes that `SubgraphMerge` must keep apart.
    let enemies = edge_barrier_pairs
        .iter()
        .copied()
        .chain(access_group_pairs.iter().copied())
        .chain(partitioned_graph.node_ids().flat_map(|dst| {
            partitioned_graph
                .node_handoff_references(dst)
                .iter()
                .filter_map(|r| r.node_id)
                .map(move |src| (src, dst))
        }));

    // Construction fails with the offending node cycle; turn that into a user-facing diagnostic.
    let mut subgraph_unionfind = SubgraphMerge::<GraphNodeId>::new(
        partitioned_graph.node_ids(),
        |node_id| all_preds.get(node_id).into_iter().flatten().copied(),
        enemies,
    )
    .map_err(|cycle| {
        let span = cycle
            .first()
            .map(|&node_id| partitioned_graph.node(node_id).span())
            .unwrap_or_else(proc_macro2::Span::call_site);
        let node_cycle = cycle
            .iter()
            .map(|&node_id| partitioned_graph.node(node_id).to_pretty_string())
            .collect::<Vec<_>>();
        Diagnostic::spanned(
            span,
            Level::Error,
            format!(
                "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks. \
                Cycle: {:?}",
                node_cycle,
            ),
        )
    })?;

    // Start with every edge as a handoff candidate. Edges are removed when they are merged or
    // when they already touch a handoff node.
    let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
    // Repeat full passes until nothing changes. An edge rejected in one pass (for example
    // because both ends were still uncolored) may become mergeable after a neighbour is merged.
    let mut progress = true;
    while progress {
        progress = false;
        for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
            // Existing handoffs stay outside subgraphs; no new handoff is needed on this edge.
            if matches!(partitioned_graph.node(src), GraphNode::Handoff { .. })
                || matches!(partitioned_graph.node(dst), GraphNode::Handoff { .. })
            {
                handoff_edges.remove(&edge_id);
                continue;
            }

            // Already in one set, possibly via a different path. The edge is left in
            // `handoff_edges`, so a handoff will still be inserted on it.
            if subgraph_unionfind.same_set(src, dst) {
                continue;
            }

            // Do not merge across different loops.
            if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
                continue;
            }
            // Never merge into a `NextIteration` operator.
            if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
                Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
            }) {
                continue;
            }

            // NOTE(review): `can_connect_colorize` may write an inferred color into
            // `node_color` before `try_merge` runs. If `try_merge` then fails (for example on an
            // enemy pair), that color stays, which may block later merges. Confirm whether this
            // is intended or whether the color should be rolled back.
            if can_connect_colorize(&mut node_color, src, dst) {
                let ok = subgraph_unionfind.try_merge(src, dst);
                if ok {
                    // A merged edge lies inside a subgraph and needs no handoff.
                    assert!(handoff_edges.remove(&edge_id));
                    progress = true;
                }
            }
        }
    }

    Ok((subgraph_unionfind, handoff_edges))
}
221
222fn make_subgraphs(
224 partitioned_graph: &mut DfirGraph,
225 tick_edges: &mut SecondaryMap<GraphEdgeId, DelayType>,
226 edge_barrier_pairs: &[(GraphNodeId, GraphNodeId)],
227 access_group_pairs: &[(GraphNodeId, GraphNodeId)],
228) -> Result<(), Diagnostic> {
229 let (subgraph_merge, handoff_edges) = find_subgraph_unionfind(
238 partitioned_graph,
239 tick_edges,
240 edge_barrier_pairs,
241 access_group_pairs,
242 )?;
243
244 for edge_id in handoff_edges {
246 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
247
248 let src_node = partitioned_graph.node(src_id);
250 let dst_node = partitioned_graph.node(dst_id);
251 if matches!(src_node, GraphNode::Handoff { .. })
252 || matches!(dst_node, GraphNode::Handoff { .. })
253 {
254 continue;
255 }
256
257 let hoff = GraphNode::Handoff {
258 kind: HandoffKind::Vec,
259 src_span: src_node.span(),
260 dst_span: dst_node.span(),
261 };
262 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
263
264 if let Some(delay_type) = tick_edges.remove(edge_id) {
266 tick_edges.insert(out_edge_id, delay_type);
267 }
268 }
269
270 let mut subgraph_toposort = Vec::new();
273 for nodes in subgraph_merge.subgraphs() {
274 if nodes.is_empty() {
275 continue;
276 }
277 if nodes
279 .iter()
280 .any(|&n| matches!(partitioned_graph.node(n), GraphNode::Handoff { .. }))
281 {
282 continue;
283 }
284 let sg_id = partitioned_graph.insert_subgraph(nodes.to_vec()).unwrap();
285 subgraph_toposort.push(sg_id);
286 }
287 partitioned_graph.set_subgraph_toposort(subgraph_toposort);
288 Ok(())
289}
290
291fn can_connect_colorize(
297 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
298 src: GraphNodeId,
299 dst: GraphNodeId,
300) -> bool {
301 let can_connect = match (node_color.get(src), node_color.get(dst)) {
306 (None, None) => false,
309
310 (None, Some(Color::Pull | Color::Comp)) => {
312 node_color.insert(src, Color::Pull);
313 true
314 }
315 (None, Some(Color::Push | Color::Hoff)) => {
316 node_color.insert(src, Color::Push);
317 true
318 }
319
320 (Some(Color::Pull | Color::Hoff), None) => {
322 node_color.insert(dst, Color::Pull);
323 true
324 }
325 (Some(Color::Comp | Color::Push), None) => {
326 node_color.insert(dst, Color::Push);
327 true
328 }
329
330 (Some(Color::Pull), Some(Color::Pull)) => true,
332 (Some(Color::Pull), Some(Color::Comp)) => true,
333 (Some(Color::Pull), Some(Color::Push)) => true,
334
335 (Some(Color::Comp), Some(Color::Pull)) => false,
336 (Some(Color::Comp), Some(Color::Comp)) => false,
337 (Some(Color::Comp), Some(Color::Push)) => true,
338
339 (Some(Color::Push), Some(Color::Pull)) => false,
340 (Some(Color::Push), Some(Color::Comp)) => false,
341 (Some(Color::Push), Some(Color::Push)) => true,
342
343 (Some(Color::Hoff), Some(_)) => false,
345 (Some(_), Some(Color::Hoff)) => false,
346 };
347 can_connect
348}
349
350fn mark_tick_boundary_handoffs(
353 partitioned_graph: &mut DfirGraph,
354 tick_edges: &SecondaryMap<GraphEdgeId, DelayType>,
355) {
356 let tick_handoffs: Vec<_> = partitioned_graph
357 .nodes()
358 .filter_map(|(hoff_id, hoff)| {
359 if !matches!(hoff, GraphNode::Handoff { .. }) {
360 return None;
361 }
362 if partitioned_graph.node_degree_out(hoff_id) == 0 {
363 return None;
364 }
365 let (succ_edge, _) = partitioned_graph.node_successors(hoff_id).next().unwrap();
366 let &delay_type = tick_edges.get(succ_edge)?;
367 Some((hoff_id, delay_type))
368 })
369 .collect();
370
371 for (hoff_id, delay_type) in tick_handoffs {
372 partitioned_graph.set_handoff_delay_type(hoff_id, delay_type);
373 }
374}
375
376pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
380 let (mut tick_edges, edge_barrier_pairs) = find_edge_barriers(&flat_graph);
381 let access_group_pairs = find_access_group_ordering(&flat_graph);
382 let mut partitioned_graph = flat_graph;
383
384 make_subgraphs(
386 &mut partitioned_graph,
387 &mut tick_edges,
388 &edge_barrier_pairs,
389 &access_group_pairs,
390 )?;
391
392 mark_tick_boundary_handoffs(&mut partitioned_graph, &tick_edges);
394
395 Ok(partitioned_graph)
396}