1use crate::facts::RawNode;
7use icb_common::{IcbError, Language, NodeKind};
8use tree_sitter::Parser;
9
10use super::common::{child_of_kind, traverse_node};
11
12pub fn parse_rust(source: &str) -> Result<Vec<RawNode>, IcbError> {
14 let mut parser = Parser::new();
15 parser
16 .set_language(&tree_sitter_rust::language())
17 .map_err(|e| IcbError::Parse(format!("cannot set tree-sitter-rust language: {e}")))?;
18
19 let tree = parser
20 .parse(source, None)
21 .ok_or_else(|| IcbError::Parse("tree-sitter parse returned None for Rust source".into()))?;
22
23 let mut facts = Vec::new();
24
25 let classifier =
26 |node: &tree_sitter::Node, source: &str| -> Option<(NodeKind, Option<String>, bool)> {
27 match node.kind() {
28 "function_item" | "function_signature_item" => {
29 let name = child_of_kind(*node, "identifier")
30 .or_else(|| child_of_kind(*node, "name"))
31 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
32 .map(|s| s.to_string());
33 Some((NodeKind::Function, name, true))
34 }
35 "impl_item" | "trait_item" | "struct_item" | "enum_item" | "union_item" => {
36 let name = child_of_kind(*node, "type_identifier")
37 .or_else(|| child_of_kind(*node, "identifier"))
38 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
39 .map(|s| s.to_string());
40 Some((NodeKind::Class, name, true))
41 }
42 "call_expression" | "macro_invocation" => {
43 let name_node = child_of_kind(*node, "identifier")
44 .or_else(|| child_of_kind(*node, "field_expression"))
45 .or_else(|| child_of_kind(*node, "scoped_identifier"));
46 let name = name_node
47 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
48 .map(|s| s.to_string());
49 Some((NodeKind::CallSite, name, false))
50 }
51 _ => None,
52 }
53 };
54
55 traverse_node(
56 tree.root_node(),
57 source,
58 &mut facts,
59 None,
60 Language::Rust,
61 &classifier,
62 );
63 Ok(facts)
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use icb_common::NodeKind;

    #[test]
    fn test_simple_function() {
        let code = "fn foo() {}";
        let facts = parse_rust(code).unwrap();
        let funcs: Vec<_> = facts
            .iter()
            .filter(|n| n.kind == NodeKind::Function)
            .collect();
        assert_eq!(funcs.len(), 1);
        assert_eq!(funcs[0].name.as_deref(), Some("foo"));
    }

    #[test]
    fn test_method_in_impl() {
        let code = "struct S; impl S { fn bar(&self) {} }";
        let facts = parse_rust(code).unwrap();
        let methods: Vec<_> = facts
            .iter()
            .filter(|n| n.kind == NodeKind::Function && n.name.as_deref() == Some("bar"))
            .collect();
        // Exactly one: a method nested in an impl must be reported once, not duplicated.
        assert_eq!(methods.len(), 1);

        // Both the struct declaration and its inherent impl are recorded under the type's name.
        let classes: Vec<_> = facts
            .iter()
            .filter(|n| n.kind == NodeKind::Class)
            .map(|n| n.name.as_deref())
            .collect();
        assert_eq!(classes, vec![Some("S"), Some("S")]);
    }

    #[test]
    fn test_call_expression() {
        let code = "fn baz() { foo(); }";
        let facts = parse_rust(code).unwrap();
        let calls: Vec<_> = facts
            .iter()
            .filter(|n| n.kind == NodeKind::CallSite)
            .collect();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name.as_deref(), Some("foo"));
    }

    #[test]
    fn test_callee_name_forms() {
        // Method call, path call and macro invocation each yield one call site named by callee.
        let code = "fn baz(v: Vec<u8>) { v.len(); std::mem::drop(1); println!(\"hi\"); }";
        let facts = parse_rust(code).unwrap();
        let mut names: Vec<_> = facts
            .iter()
            .filter(|n| n.kind == NodeKind::CallSite)
            .map(|n| n.name.as_deref())
            .collect();
        // Sort so the assertion does not depend on traversal order.
        names.sort();
        assert_eq!(
            names,
            vec![Some("println"), Some("std::mem::drop"), Some("v.len")]
        );
    }

    #[test]
    fn test_struct_type() {
        let code = "struct MyStruct {}";
        let facts = parse_rust(code).unwrap();
        let classes: Vec<_> = facts.iter().filter(|n| n.kind == NodeKind::Class).collect();
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].name.as_deref(), Some("MyStruct"));
    }
}