Skip to main content

icb_parser/lang/
rust.rs

//! Rust language parser using tree-sitter-rust.
//!
//! Extracts function items and signatures (including methods), call
//! expressions and macro invocations, and struct/enum/union/trait/impl
//! items from Rust source files.
5
6use crate::facts::RawNode;
7use icb_common::{IcbError, Language, NodeKind};
8use tree_sitter::Parser;
9
10use super::common::{child_of_kind, traverse_node};
11
12/// Parse Rust source code and return the extracted facts.
13pub fn parse_rust(source: &str) -> Result<Vec<RawNode>, IcbError> {
14    let mut parser = Parser::new();
15    parser
16        .set_language(&tree_sitter_rust::language())
17        .map_err(|e| IcbError::Parse(format!("cannot set tree-sitter-rust language: {e}")))?;
18
19    let tree = parser
20        .parse(source, None)
21        .ok_or_else(|| IcbError::Parse("tree-sitter parse returned None for Rust source".into()))?;
22
23    let mut facts = Vec::new();
24
25    let classifier =
26        |node: &tree_sitter::Node, source: &str| -> Option<(NodeKind, Option<String>, bool)> {
27            match node.kind() {
28                "function_item" | "function_signature_item" => {
29                    let name = child_of_kind(*node, "identifier")
30                        .or_else(|| child_of_kind(*node, "name"))
31                        .and_then(|n| n.utf8_text(source.as_bytes()).ok())
32                        .map(|s| s.to_string());
33                    Some((NodeKind::Function, name, true))
34                }
35                "impl_item" | "trait_item" | "struct_item" | "enum_item" | "union_item" => {
36                    let name = child_of_kind(*node, "type_identifier")
37                        .or_else(|| child_of_kind(*node, "identifier"))
38                        .and_then(|n| n.utf8_text(source.as_bytes()).ok())
39                        .map(|s| s.to_string());
40                    Some((NodeKind::Class, name, true))
41                }
42                "call_expression" | "macro_invocation" => {
43                    let name_node = child_of_kind(*node, "identifier")
44                        .or_else(|| child_of_kind(*node, "field_expression"))
45                        .or_else(|| child_of_kind(*node, "scoped_identifier"));
46                    let name = name_node
47                        .and_then(|n| n.utf8_text(source.as_bytes()).ok())
48                        .map(|s| s.to_string());
49                    Some((NodeKind::CallSite, name, false))
50                }
51                _ => None,
52            }
53        };
54
55    traverse_node(
56        tree.root_node(),
57        source,
58        &mut facts,
59        None,
60        Language::Rust,
61        &classifier,
62    );
63    Ok(facts)
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use icb_common::NodeKind;

    /// Parses `code` and returns the name of every fact of kind `wanted`,
    /// in the order the parser emitted them.
    fn names_of(code: &str, wanted: NodeKind) -> Vec<Option<String>> {
        let facts = parse_rust(code).unwrap();
        facts
            .iter()
            .filter(|fact| fact.kind == wanted)
            .map(|fact| fact.name.as_deref().map(str::to_owned))
            .collect()
    }

    /// A lone free function yields exactly one `Function` fact carrying its name.
    #[test]
    fn test_simple_function() {
        let functions = names_of("fn foo() {}", NodeKind::Function);
        assert_eq!(functions, vec![Some("foo".to_string())]);
    }

    /// A method declared inside an `impl` block is reported as a `Function`.
    #[test]
    fn test_method_in_impl() {
        let functions = names_of("struct S; impl S { fn bar(&self) {} }", NodeKind::Function);
        assert!(functions.iter().any(|name| name.as_deref() == Some("bar")));
    }

    /// A single call in a function body yields exactly one `CallSite` fact.
    #[test]
    fn test_call_expression() {
        let call_sites = names_of("fn baz() { foo(); }", NodeKind::CallSite);
        assert_eq!(call_sites, vec![Some("foo".to_string())]);
    }

    /// A struct definition yields exactly one `Class` fact carrying its name.
    #[test]
    fn test_struct_type() {
        let types = names_of("struct MyStruct {}", NodeKind::Class);
        assert_eq!(types, vec![Some("MyStruct".to_string())]);
    }
}
114}