icb_parser/
cpp_tree_sitter.rs1use icb_common::{IcbError, Language, NodeKind};
42use tree_sitter::{Node, Parser};
43
44use crate::facts::RawNode;
45
46pub fn parse_cpp_file(source: &str) -> Result<Vec<RawNode>, IcbError> {
48 let mut parser = Parser::new();
49 parser
50 .set_language(&tree_sitter_cpp::language())
51 .map_err(|e| IcbError::Parse(format!("cannot set tree-sitter-cpp language: {e}")))?;
52
53 let tree = parser
54 .parse(source, None)
55 .ok_or_else(|| IcbError::Parse("tree-sitter parse returned None".into()))?;
56
57 let mut facts = Vec::new();
58 traverse_node(tree.root_node(), source, &mut facts, None);
59 Ok(facts)
60}
61
62fn traverse_node(
63 node: Node,
64 source: &str,
65 facts: &mut Vec<RawNode>,
66 parent_idx: Option<usize>,
67) -> Option<usize> {
68 let kind = node.kind();
69
70 let (node_kind, name, is_container) = match kind {
71 "function_definition" | "function_declaration" | "template_declaration" => {
72 let name = function_name(node, source).unwrap_or_default();
73 (NodeKind::Function, Some(name), true)
74 }
75 "class_specifier" | "struct_specifier" | "interface_specifier" | "union_specifier" => {
76 let name = node
77 .child_by_field_name("name")
78 .map(|n| {
79 n.utf8_text(source.as_bytes())
80 .unwrap_or_default()
81 .to_string()
82 })
83 .unwrap_or_default();
84 (NodeKind::Class, Some(name), true)
85 }
86 "call_expression" => {
87 let name = node
88 .child_by_field_name("function")
89 .map(|n| {
90 n.utf8_text(source.as_bytes())
91 .unwrap_or_default()
92 .to_string()
93 })
94 .unwrap_or_default();
95 (NodeKind::CallSite, Some(name), false)
96 }
97 "declaration" => {
98 let name = node
99 .child_by_field_name("declarator")
100 .or_else(|| node.child_by_field_name("name"))
101 .map(|n| {
102 n.utf8_text(source.as_bytes())
103 .unwrap_or_default()
104 .to_string()
105 })
106 .unwrap_or_default();
107 if parent_kind_is(node, "parameter_list") {
108 (NodeKind::Parameter, Some(name), false)
109 } else {
110 (NodeKind::Variable, Some(name), false)
111 }
112 }
113 _ => {
114 let mut current_parent = parent_idx;
115 for child in node.children(&mut node.walk()) {
116 current_parent = traverse_node(child, source, facts, current_parent);
117 }
118 return current_parent;
119 }
120 };
121
122 let start = node.start_position();
123 let end = node.end_position();
124
125 let start_line = start.row + 1;
126 let end_line = std::cmp::max(end.row + 1, start_line);
127
128 let idx = facts.len();
129 facts.push(RawNode {
130 language: Language::CppTreeSitter,
131 kind: node_kind,
132 name,
133 usr: None,
134 start_line,
135 start_col: start.column,
136 end_line,
137 end_col: end.column,
138 children: Vec::new(),
139 source_file: None,
140 });
141
142 debug_assert!(
143 end_line >= start_line,
144 "end_line {} must be >= start_line {}",
145 end_line,
146 start_line
147 );
148
149 if let Some(pidx) = parent_idx {
150 facts[pidx].children.push(idx);
151 }
152
153 if is_container {
154 let new_parent = Some(idx);
155 let mut current_parent = new_parent;
156 for child in node.children(&mut node.walk()) {
157 current_parent = traverse_node(child, source, facts, current_parent);
158 }
159 new_parent
160 } else {
161 parent_idx
162 }
163}
164
165fn function_name(node: Node, source: &str) -> Option<String> {
167 if let Some(decl) = node.child_by_field_name("declarator") {
168 if decl.kind() == "function_declarator" {
169 if let Some(name_node) = decl.child_by_field_name("declarator") {
170 return Some(name_node.utf8_text(source.as_bytes()).ok()?.to_string());
171 }
172 return Some(decl.utf8_text(source.as_bytes()).ok()?.to_string());
173 }
174 return Some(decl.utf8_text(source.as_bytes()).ok()?.to_string());
175 }
176 let mut cursor = node.walk();
177 for child in node.children(&mut cursor) {
178 if child.kind() == "identifier" {
179 return child
180 .utf8_text(source.as_bytes())
181 .ok()
182 .map(|s| s.to_string());
183 }
184 }
185 None
186}
187
188fn parent_kind_is(node: Node, expected: &str) -> bool {
190 node.parent().is_some_and(|p| p.kind() == expected)
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone free function produces exactly one fact: a `Function` named
    /// after it whose end line does not precede its start line.
    #[test]
    fn parse_simple_function() {
        let facts = parse_cpp_file("void foo() {}").unwrap();
        assert_eq!(facts.len(), 1);

        let only = &facts[0];
        assert_eq!(only.kind, NodeKind::Function);
        assert_eq!(only.name.as_deref(), Some("foo"));
        assert!(only.end_line >= only.start_line);
    }

    /// A class definition is reported as a `Class` fact carrying its name.
    #[test]
    fn parse_class() {
        let found = parse_cpp_file("class MyClass { void bar(); };")
            .unwrap()
            .iter()
            .any(|fact| fact.kind == NodeKind::Class && fact.name.as_deref() == Some("MyClass"));
        assert!(found);
    }

    /// A class wrapped in a template declaration is still reported by name.
    #[test]
    fn parse_template_class() {
        let found = parse_cpp_file("template <typename T> class Container { T value; };")
            .unwrap()
            .iter()
            .any(|fact| fact.kind == NodeKind::Class && fact.name.as_deref() == Some("Container"));
        assert!(found);
    }

    /// A call inside a function body yields at least one `CallSite` fact.
    #[test]
    fn parse_function_with_call() {
        let facts = parse_cpp_file("void bar() {} void baz() { bar(); }").unwrap();
        let call_count = facts
            .iter()
            .filter(|fact| fact.kind == NodeKind::CallSite)
            .count();
        assert!(call_count > 0);
    }
}