1use crate::*;
6use std::io::Write;
7
/// Metadata describing one argument of an exported Rust function.
#[derive(Debug)]
#[derive(PartialEq)]
pub struct Arg {
    /// Argument name as written on the Rust side; it is passed through
    /// `sanitize_identifier` before being emitted into R code.
    pub name: &'static str,
    /// Textual type of the argument (presumably the Rust type as written —
    /// confirm with the macro that fills this in); exported as `arg_type`.
    pub arg_type: &'static str,
    /// Optional default value, emitted verbatim as `name = default` in the
    /// R formal argument list.
    pub default: Option<&'static str>,
}
15
16#[derive(Debug, PartialEq)]
18pub struct Func {
19 pub doc: &'static str,
20 pub rust_name: &'static str,
21 pub mod_name: &'static str,
22 pub r_name: &'static str,
23 pub args: Vec<Arg>,
24 pub return_type: &'static str,
25 pub func_ptr: *const u8,
26 pub hidden: bool,
27}
28
29#[derive(Debug, PartialEq)]
31pub struct Impl {
32 pub doc: &'static str,
33 pub name: &'static str,
34 pub methods: Vec<Func>,
35}
36
37#[derive(Debug, PartialEq)]
39pub struct Metadata {
40 pub name: &'static str,
41 pub functions: Vec<Func>,
42 pub impls: Vec<Impl>,
43}
44
/// An argument as it will appear in generated R code.
struct RArg {
    /// Sanitized R identifier for the argument.
    name: String,
    /// Optional default value, written verbatim after `=` in formals.
    default: Option<&'static str>,
}
49
50impl RArg {
51 fn is_self(&self) -> bool {
52 self.name == "self"
53 }
54
55 fn to_actual_arg(&self) -> String {
56 self.name.clone()
57 }
58
59 fn to_formal_arg(&self) -> String {
60 match self.default {
61 Some(default_val) => format!("{} = {}", self.name, default_val),
62 None => self.name.clone(),
63 }
64 }
65}
66
67impl From<&Arg> for RArg {
68 fn from(arg: &Arg) -> Self {
69 Self {
70 name: sanitize_identifier(arg.name),
71 default: arg.default,
72 }
73 }
74}
75
76impl From<Arg> for Robj {
77 fn from(val: Arg) -> Self {
78 use crate as extendr_api;
79 let mut result = List::from_values(&[r!(val.name), r!(val.arg_type)]);
80 result
81 .set_names(&["name", "arg_type"])
82 .expect("From<Arg> failed");
83 result.into()
84 }
85}
86
87impl From<Func> for Robj {
88 fn from(val: Func) -> Self {
89 use crate as extendr_api;
90 let mut result = List::from_values(&[
91 r!(val.doc),
92 r!(val.rust_name),
93 r!(val.mod_name),
94 r!(val.r_name),
95 r!(List::from_values(val.args)),
96 r!(val.return_type),
97 r!(val.hidden),
98 ]);
99 result
100 .set_names(&[
101 "doc",
102 "rust_name",
103 "mod_name",
104 "r_name",
105 "args",
106 "return.type",
107 "hidden",
108 ])
109 .expect("From<Func> failed");
110 result.into()
111 }
112}
113
114impl From<Impl> for Robj {
115 fn from(val: Impl) -> Self {
116 use crate as extendr_api;
117 let mut result = List::from_values(&[
118 r!(val.doc),
119 r!(val.name),
120 r!(List::from_values(val.methods)),
121 ]);
122 result
123 .set_names(&["doc", "name", "methods"])
124 .expect("From<Impl> failed");
125 result.into()
126 }
127}
128
129impl From<Metadata> for Robj {
130 fn from(val: Metadata) -> Self {
131 use crate as extendr_api;
132 let mut result = List::from_values(&[
133 r!(val.name),
134 r!(List::from_values(val.functions)),
135 r!(List::from_values(val.impls)),
136 ]);
137 result
138 .set_names(&["name", "functions", "impls"])
139 .expect("From<Metadata> failed");
140 result.into()
141 }
142}
143
/// Writes `doc` as a roxygen comment block.
///
/// Each `\n`-separated line (including an empty trailing one) is written
/// prefixed with `#'` and terminated by a newline. An empty `doc` writes
/// nothing.
fn write_doc(w: &mut Vec<u8>, doc: &str) -> std::io::Result<()> {
    if doc.is_empty() {
        return Ok(());
    }
    for line in doc.split('\n') {
        writeln!(w, "#'{}", line)?;
    }
    Ok(())
}
158
/// Turns a Rust identifier into a name that can be written verbatim in R code.
///
/// * A raw-identifier prefix (`r#`) is removed, since R has no such syntax.
/// * Names R cannot parse as bare symbols are wrapped in backticks. These are
///   names starting with `_` and R's reserved words (see R's `?Reserved`),
///   such as `next`, `function`, `TRUE` or `in`.
/// * Every other name is returned unchanged.
fn sanitize_identifier(name: &str) -> String {
    // R reserved words (`?Reserved`). Some are ordinary identifiers in Rust
    // (`next`, `repeat`, `function`, `TRUE`, `NULL`, ...). Others become
    // reachable once a raw identifier is stripped (`r#in`, `r#if`, ...).
    // Emitting any of them bare yields unparsable R such as `function(next)`.
    const R_RESERVED: &[&str] = &[
        "if",
        "else",
        "repeat",
        "while",
        "function",
        "for",
        "in",
        "next",
        "break",
        "TRUE",
        "FALSE",
        "NULL",
        "Inf",
        "NaN",
        "NA",
        "NA_integer_",
        "NA_real_",
        "NA_character_",
        "NA_complex_",
    ];

    let name = name.strip_prefix("r#").unwrap_or(name);
    if name.starts_with('_') || R_RESERVED.contains(&name) {
        format!("`{}`", name)
    } else {
        name.to_string()
    }
}
170
/// Concatenates the strings yielded by `input`, placing `sep` between
/// consecutive items. An empty iterator yields an empty string.
fn join_str(input: impl Iterator<Item = String>, sep: &str) -> String {
    let mut joined = String::new();
    for (idx, piece) in input.enumerate() {
        if idx > 0 {
            joined.push_str(sep);
        }
        joined.push_str(&piece);
    }
    joined
}
174
175fn write_function_wrapper(
177 w: &mut Vec<u8>,
178 func: &Func,
179 package_name: &str,
180 use_symbols: bool,
181) -> std::io::Result<()> {
182 if func.hidden {
183 return Ok(());
184 }
185
186 write_doc(w, func.doc)?;
187
188 let r_args: Vec<RArg> = func.args.iter().map(Into::into).collect();
189 let actual_args = r_args.iter().map(|a| a.to_actual_arg());
190 let formal_args = r_args.iter().map(|a| a.to_formal_arg());
191
192 if func.return_type == "()" {
193 write!(
194 w,
195 "{} <- function({}) invisible(.Call(",
196 sanitize_identifier(func.r_name),
197 join_str(formal_args, ", ")
198 )?;
199 } else {
200 write!(
201 w,
202 "{} <- function({}) .Call(",
203 sanitize_identifier(func.r_name),
204 join_str(formal_args, ", ")
205 )?;
206 }
207
208 if use_symbols {
209 write!(w, "wrap__{}", func.mod_name)?;
210 } else {
211 write!(w, "\"wrap__{}\"", func.mod_name)?;
212 }
213
214 if !func.args.is_empty() {
215 write!(w, ", {}", join_str(actual_args, ", "))?;
216 }
217
218 if !use_symbols {
219 write!(w, ", PACKAGE = \"{}\"", package_name)?;
220 }
221
222 if func.return_type == "()" {
223 writeln!(w, "))\n")?;
224 } else {
225 writeln!(w, ")\n")?;
226 }
227
228 Ok(())
229}
230
231fn write_method_wrapper(
233 w: &mut Vec<u8>,
234 func: &Func,
235 package_name: &str,
236 use_symbols: bool,
237 class_name: &str,
238) -> std::io::Result<()> {
239 if func.hidden {
240 return Ok(());
241 }
242
243 let r_args: Vec<RArg> = func.args.iter().map(Into::into).collect();
244 let actual_args = r_args.iter().map(|a| a.to_actual_arg());
245
246 let formal_args = r_args
249 .iter()
250 .skip_while(|a| a.is_self())
251 .map(|a| a.to_formal_arg());
252
253 if func.return_type == "()" {
256 write!(
257 w,
258 "{}${} <- function({}) invisible(.Call(",
259 sanitize_identifier(class_name),
260 sanitize_identifier(func.r_name),
261 join_str(formal_args, ", ")
262 )?;
263 } else {
264 write!(
265 w,
266 "{}${} <- function({}) .Call(",
267 sanitize_identifier(class_name),
268 sanitize_identifier(func.r_name),
269 join_str(formal_args, ", ")
270 )?;
271 }
272
273 if use_symbols {
275 write!(w, "wrap__{}__{}", class_name, func.mod_name)?;
276 } else {
277 write!(w, "\"wrap__{}__{}\"", class_name, func.mod_name)?;
278 }
279
280 if actual_args.len() != 0 {
281 write!(w, ", {}", join_str(actual_args, ", "))?;
282 }
283
284 if !use_symbols {
285 write!(w, ", PACKAGE = \"{}\"", package_name)?;
286 }
287
288 if func.return_type == "()" {
289 writeln!(w, "))\n")?;
290 } else {
291 writeln!(w, ")\n")?;
292 }
293
294 Ok(())
295}
296
297fn write_impl_wrapper(
299 w: &mut Vec<u8>,
300 name: &str,
301 impls: &[Impl],
302 package_name: &str,
303 use_symbols: bool,
304) -> std::io::Result<()> {
305 let mut exported = false;
306 {
307 for imp in impls.iter().filter(|imp| imp.name == name) {
308 if !exported {
309 exported = imp.doc.contains("@export");
310 }
311 write_doc(w, imp.doc)?;
312 }
313 }
314
315 let imp_name_fixed = sanitize_identifier(name);
316
317 writeln!(w, "{} <- new.env(parent = emptyenv())\n", imp_name_fixed)?;
319
320 for imp in impls.iter().filter(|imp| imp.name == name) {
321 for func in &imp.methods {
322 write_method_wrapper(w, func, package_name, use_symbols, imp.name)?;
325 }
326 }
327
328 if exported {
329 writeln!(w, "#' @rdname {}", name)?;
330 writeln!(w, "#' @usage NULL")?;
331 }
332
333 writeln!(w, "#' @export")?;
338
339 writeln!(w, "`$.{}` <- function (self, name) {{ func <- {}[[name]]; environment(func) <- environment(); func }}\n", name, imp_name_fixed)?;
343
344 writeln!(w, "#' @export")?;
345 writeln!(w, "`[[.{}` <- `$.{}`\n", name, name)?;
346
347 Ok(())
348}
349
350impl Metadata {
351 pub fn make_r_wrappers(
352 &self,
353 use_symbols: bool,
354 package_name: &str,
355 ) -> std::io::Result<String> {
356 let mut w = Vec::new();
357
358 writeln!(
359 w,
360 r#"# Generated by extendr: Do not edit by hand
361#
362# This file was created with the following call:
363# .Call("wrap__make_{}_wrappers", use_symbols = {}, package_name = "{}")
364"#,
365 self.name,
366 if use_symbols { "TRUE" } else { "FALSE" },
367 package_name
368 )?;
369
370 if use_symbols {
371 writeln!(w, "#' @usage NULL")?;
372 writeln!(w, "#' @useDynLib {}, .registration = TRUE", package_name)?;
373 writeln!(w, "NULL")?;
374 writeln!(w)?;
375 }
376
377 for func in &self.functions {
378 write_function_wrapper(&mut w, func, package_name, use_symbols)?;
379 }
380
381 for name in self.impl_names() {
382 write_impl_wrapper(&mut w, name, &self.impls, package_name, use_symbols)?;
383 }
384
385 unsafe { Ok(String::from_utf8_unchecked(w)) }
386 }
387
388 fn impl_names<'a>(&'a self) -> Vec<&'a str> {
389 let mut vec: Vec<&str> = vec![];
390 for impls in &self.impls {
391 if !vec.contains(&impls.name) {
392 vec.push(&impls.name)
393 }
394 }
395 vec
396 }
397}