leibniz/lib/codegen.ml
2026-01-19 03:37:26 +00:00

95 lines
4.1 KiB
OCaml

open Expr
open Cse
(** [c_type_of_var v] ignores [v] and always returns the string ["double"].
    NOTE(review): nothing in the visible code calls this; presumably it is
    meant to name the C type used when declaring a variable. *)
let c_type_of_var = fun _ -> "double"
(** [c_float_literal f] renders [f] as C source text denoting the same double.

    - Finite values use the shortest of 15, 16 or 17 significant digits that
      reads back to exactly [f], so no precision is lost in the generated code.
    - A trailing [.] is appended when the text would otherwise look like an
      integer, so that C does not switch to integer arithmetic.
    - Negative values are parenthesised so that a preceding minus sign can
      never fuse with the sign of the literal into the [--] operator.
    - NaN and the infinities are emitted as the C99 [math.h] macros [NAN] and
      [INFINITY], since [nan] and [inf] are not C literals. *)
let c_float_literal f =
  match classify_float f with
  | FP_nan -> "NAN"
  | FP_infinite -> if f > 0.0 then "INFINITY" else "(-INFINITY)"
  | FP_zero | FP_subnormal | FP_normal ->
    let exact s = float_of_string s = f in
    let digits =
      let short = Printf.sprintf "%.15g" f in
      if exact short then short
      else
        let mid = Printf.sprintf "%.16g" f in
        if exact mid then mid else Printf.sprintf "%.17g" f
    in
    (* The g conversion drops the decimal point for integral values. *)
    let literal =
      if String.contains digits '.' || String.contains digits 'e' then digits
      else digits ^ "."
    in
    if literal.[0] = '-' then "(" ^ literal ^ ")" else literal

(** [expr_to_c e] translates the expression [e] into the text of a C
    expression.

    - Binary arithmetic is fully parenthesised, so the result can be embedded
      in a larger expression without precedence issues.
    - Elementary functions map to the C math functions of the same name;
      [Ln] maps to [log] and [Abs] to [fabs].
    - [Log (b, x)] is emitted as [log(x) / log(b)], the first component
      being used as the base.
    - Variables are emitted verbatim, without any sanitising.
    - Constants go through {!c_float_literal}. *)
let rec expr_to_c expr =
  let infix op a b =
    Printf.sprintf "(%s %s %s)" (expr_to_c a) op (expr_to_c b)
  in
  let call1 fn a = Printf.sprintf "%s(%s)" fn (expr_to_c a) in
  let call2 fn a b =
    Printf.sprintf "%s(%s, %s)" fn (expr_to_c a) (expr_to_c b)
  in
  match expr with
  | Const f -> c_float_literal f
  | SymConst Pi -> "M_PI"
  | SymConst E -> "M_E"
  | Var v -> v
  | Add (a, b) -> infix "+" a b
  | Sub (a, b) -> infix "-" a b
  | Mul (a, b) -> infix "*" a b
  | Div (a, b) -> infix "/" a b
  | Pow (a, b) -> call2 "pow" a b
  | Neg a -> Printf.sprintf "(-%s)" (expr_to_c a)
  | Sin a -> call1 "sin" a
  | Cos a -> call1 "cos" a
  | Tan a -> call1 "tan" a
  | Sinh a -> call1 "sinh" a
  | Cosh a -> call1 "cosh" a
  | Tanh a -> call1 "tanh" a
  | Asin a -> call1 "asin" a
  | Acos a -> call1 "acos" a
  | Atan a -> call1 "atan" a
  | Atan2 (a, b) -> call2 "atan2" a b
  | Exp a -> call1 "exp" a
  | Ln a -> call1 "log" a
  | Log (base, arg) ->
    (* Change of base: log_base(arg) = log(arg) / log(base). *)
    Printf.sprintf "(%s / %s)" (call1 "log" arg) (call1 "log" base)
  | Sqrt a -> call1 "sqrt" a
  | Abs a -> call1 "fabs" a
(** [compile_to_c expr vars] returns the text of a C translation unit that
    defines [double compute(...)], taking one [double] parameter per name in
    [vars], in order.

    The expression is first passed through
    [common_subexpression_elimination]. Each resulting subexpression is
    emitted as a local [double] declaration, in the order the list gives
    them, followed by a [return] of the final expression. The output starts
    with an include of [math.h]. *)
let compile_to_c expr vars =
  let cse = common_subexpression_elimination expr in
  let param_list =
    vars
    |> List.map (fun v -> Printf.sprintf "double %s" v)
    |> String.concat ", "
  in
  let out = Buffer.create 1024 in
  Buffer.add_string out "#include <math.h>\n\n";
  Printf.bprintf out "double compute(%s) {\n" param_list;
  List.iter
    (fun (tmp, sub) ->
      Printf.bprintf out " double %s = %s;\n" tmp (expr_to_c sub))
    cse.subexpressions;
  Printf.bprintf out " return %s;\n" (expr_to_c cse.final_expr);
  Buffer.add_string out "}\n";
  Buffer.contents out
(** [compile_to_cuda expr vars] returns CUDA source text containing two
    functions:

    - [compute_device], a [__device__] function with one [double] parameter
      per name in [vars], whose body is the common-subexpression temporaries
      followed by a [return] of the final expression;
    - [compute_kernel], a [__global__] kernel that computes a flat thread
      index [idx] and, when [idx < n], stores into [output[idx]] the result
      of calling [compute_device] on [input[idx * k + j]] for [j] from [0]
      to [k - 1], where [k] is the length of [vars]. *)
let compile_to_cuda expr vars =
  let cse = common_subexpression_elimination expr in
  let signature =
    vars
    |> List.map (fun v -> Printf.sprintf "double %s" v)
    |> String.concat ", "
  in
  (* Number of inputs consumed per output element. *)
  let stride = List.length vars in
  let call_args =
    vars
    |> List.mapi (fun k _ -> Printf.sprintf "input[idx * %d + %d]" stride k)
    |> String.concat ", "
  in
  let out = Buffer.create 1024 in
  Printf.bprintf out "__device__ double compute_device(%s) {\n" signature;
  List.iter
    (fun (tmp, sub) ->
      Printf.bprintf out " double %s = %s;\n" tmp (expr_to_c sub))
    cse.subexpressions;
  Printf.bprintf out " return %s;\n" (expr_to_c cse.final_expr);
  List.iter (Buffer.add_string out)
    [ "}\n\n";
      "__global__ void compute_kernel(double* input, double* output, int n) {\n";
      " int idx = blockIdx.x * blockDim.x + threadIdx.x;\n";
      " if (idx < n) {\n" ];
  Printf.bprintf out " output[idx] = compute_device(%s);\n" call_args;
  List.iter (Buffer.add_string out) [ " }\n"; "}\n" ];
  Buffer.contents out
(** [compile_to_vectorized expr vars] returns the text of a C translation
    unit defining [void compute_vectorized(...)]. The function takes one
    [const double* restrict] array per name in [vars], then
    [double* restrict output] and [int n]. Its body is an [omp simd] loop
    over [i] from [0] to [n - 1] that evaluates the expression element-wise
    and stores the result in [output[i]].

    Every occurrence of a variable listed in [vars] is rewritten to [v[i]]
    on the expression tree before printing, in both the
    common-subexpression temporaries and the final expression. Rewriting the
    tree rather than the printed text guarantees that:
    - all inputs are indexed, not only the first one;
    - the final expression is indexed as well;
    - a variable name is never replaced inside a function name (for example
      a variable [x] inside [exp]);
    - an empty [vars] neither raises nor produces a parameter list with a
      leading comma.

    Names not present in [vars] (such as the temporaries introduced by
    common-subexpression elimination) are left untouched. *)
let compile_to_vectorized expr vars =
  let cse = common_subexpression_elimination expr in
  (* Replace each input variable by its i-th element. The indexed name is
     carried in a [Var] node, which [expr_to_c] prints verbatim. *)
  let rec index_inputs e =
    match e with
    | Var v -> if List.mem v vars then Var (v ^ "[i]") else e
    | Const _ | SymConst _ -> e
    | Add (a, b) -> Add (index_inputs a, index_inputs b)
    | Sub (a, b) -> Sub (index_inputs a, index_inputs b)
    | Mul (a, b) -> Mul (index_inputs a, index_inputs b)
    | Div (a, b) -> Div (index_inputs a, index_inputs b)
    | Pow (a, b) -> Pow (index_inputs a, index_inputs b)
    | Atan2 (a, b) -> Atan2 (index_inputs a, index_inputs b)
    | Log (a, b) -> Log (index_inputs a, index_inputs b)
    | Neg a -> Neg (index_inputs a)
    | Sin a -> Sin (index_inputs a)
    | Cos a -> Cos (index_inputs a)
    | Tan a -> Tan (index_inputs a)
    | Sinh a -> Sinh (index_inputs a)
    | Cosh a -> Cosh (index_inputs a)
    | Tanh a -> Tanh (index_inputs a)
    | Asin a -> Asin (index_inputs a)
    | Acos a -> Acos (index_inputs a)
    | Atan a -> Atan (index_inputs a)
    | Exp a -> Exp (index_inputs a)
    | Ln a -> Ln (index_inputs a)
    | Sqrt a -> Sqrt (index_inputs a)
    | Abs a -> Abs (index_inputs a)
  in
  let elementwise e = expr_to_c (index_inputs e) in
  (* Join inputs and the fixed trailing parameters as one list so that an
     empty [vars] still yields a well-formed parameter list. *)
  let params =
    String.concat ", "
      (List.map (fun v -> Printf.sprintf "const double* restrict %s" v) vars
       @ [ "double* restrict output"; "int n" ])
  in
  let out = Buffer.create 1024 in
  Buffer.add_string out "#include <math.h>\n\n";
  Printf.bprintf out "void compute_vectorized(%s) {\n" params;
  Buffer.add_string out " #pragma omp simd\n";
  Buffer.add_string out " for (int i = 0; i < n; i++) {\n";
  List.iter
    (fun (tmp, sub) ->
      Printf.bprintf out " double %s = %s;\n" tmp (elementwise sub))
    cse.subexpressions;
  Printf.bprintf out " output[i] = %s;\n" (elementwise cse.final_expr);
  Buffer.add_string out " }\n";
  Buffer.add_string out "}\n";
  Buffer.contents out