95 lines
4.1 KiB
OCaml
95 lines
4.1 KiB
OCaml
open Expr
|
|
open Cse
|
|
|
|
(** [c_type_of_var v] is the C type name used for the variable [v].
    The argument is ignored: the answer is always ["double"]. *)
let c_type_of_var = fun _ -> "double"
|
|
|
|
(* [c_float_literal f] spells [f] as a C expression of floating type.

   - Finite values keep the short [string_of_float] spelling whenever it
     parses back to exactly [f]; otherwise 17 significant digits are used so
     that no precision is lost in the generated code.
   - NaN and the infinities are emitted as the [NAN] / [INFINITY] macros
     instead of the bare words [nan] / [inf], which are not C expressions.
   - Negative literals are parenthesised so that a preceding unary minus
     cannot fuse with the sign into the [--] operator. *)
let c_float_literal f =
  match classify_float f with
  | FP_nan -> "NAN"
  | FP_infinite -> if f > 0.0 then "INFINITY" else "(-INFINITY)"
  | FP_zero | FP_normal | FP_subnormal ->
    let short = string_of_float f in
    let text =
      if float_of_string short = f then short
      else
        let exact = Printf.sprintf "%.17g" f in
        (* A digits-only spelling would be an integer literal in C. *)
        if String.contains exact '.' || String.contains exact 'e' then exact
        else exact ^ "."
    in
    if text.[0] = '-' then "(" ^ text ^ ")" else text

(** [expr_to_c e] renders the expression [e] as the text of a C expression.

    - Constants are emitted through [c_float_literal].
    - [SymConst Pi] and [SymConst E] become the names [M_PI] and [M_E].
    - [Var v] is emitted verbatim as [v].
    - [Add], [Sub], [Mul], [Div] and [Neg] are wrapped in parentheses.
    - [Pow] and [Atan2] become calls to [pow] and [atan2]; the unary
      functions become calls to the C function named in each case below
      ([Ln] maps to [log], [Abs] to [fabs]).
    - [Log (b, x)] is emitted as [(log(x) / log(b))].

    @return the C source text for [e]. *)
let rec expr_to_c = function
  | Const f -> c_float_literal f
  | SymConst Pi -> "M_PI"
  | SymConst E -> "M_E"
  | Var v -> v
  | Add (e1, e2) -> Printf.sprintf "(%s + %s)" (expr_to_c e1) (expr_to_c e2)
  | Sub (e1, e2) -> Printf.sprintf "(%s - %s)" (expr_to_c e1) (expr_to_c e2)
  | Mul (e1, e2) -> Printf.sprintf "(%s * %s)" (expr_to_c e1) (expr_to_c e2)
  | Div (e1, e2) -> Printf.sprintf "(%s / %s)" (expr_to_c e1) (expr_to_c e2)
  | Pow (e1, e2) -> Printf.sprintf "pow(%s, %s)" (expr_to_c e1) (expr_to_c e2)
  | Neg e -> Printf.sprintf "(-%s)" (expr_to_c e)
  | Sin e -> Printf.sprintf "sin(%s)" (expr_to_c e)
  | Cos e -> Printf.sprintf "cos(%s)" (expr_to_c e)
  | Tan e -> Printf.sprintf "tan(%s)" (expr_to_c e)
  | Sinh e -> Printf.sprintf "sinh(%s)" (expr_to_c e)
  | Cosh e -> Printf.sprintf "cosh(%s)" (expr_to_c e)
  | Tanh e -> Printf.sprintf "tanh(%s)" (expr_to_c e)
  | Asin e -> Printf.sprintf "asin(%s)" (expr_to_c e)
  | Acos e -> Printf.sprintf "acos(%s)" (expr_to_c e)
  | Atan e -> Printf.sprintf "atan(%s)" (expr_to_c e)
  | Atan2 (e1, e2) ->
    Printf.sprintf "atan2(%s, %s)" (expr_to_c e1) (expr_to_c e2)
  | Exp e -> Printf.sprintf "exp(%s)" (expr_to_c e)
  | Ln e -> Printf.sprintf "log(%s)" (expr_to_c e)
  | Log (base_e, arg) ->
    (* Change of base: log_b(x) = log(x) / log(b). *)
    Printf.sprintf "(log(%s) / log(%s))" (expr_to_c arg) (expr_to_c base_e)
  | Sqrt e -> Printf.sprintf "sqrt(%s)" (expr_to_c e)
  | Abs e -> Printf.sprintf "fabs(%s)" (expr_to_c e)
|
|
|
|
(** [compile_to_c expr vars] returns the text of a C translation unit that
    defines [double compute(...)].

    [expr] is first passed through [common_subexpression_elimination]. Each
    resulting [(name, subexpression)] pair is emitted as a local [double]
    declaration, in list order, followed by a [return] of the final
    expression.

    @param expr the expression to compile
    @param vars names of the function's [double] parameters, in order
    @return the generated source, starting with [#include <math.h>] *)
let compile_to_c expr vars =
  let cse_result = common_subexpression_elimination expr in
  let signature =
    vars |> List.map (Printf.sprintf "double %s") |> String.concat ", "
  in
  let out = Buffer.create 1024 in
  Buffer.add_string out "#include <math.h>\n\n";
  Printf.bprintf out "double compute(%s) {\n" signature;
  (* One declaration per shared subexpression, in the order CSE produced. *)
  cse_result.subexpressions
  |> List.iter (fun (tmp, rhs) ->
       Printf.bprintf out " double %s = %s;\n" tmp (expr_to_c rhs));
  Printf.bprintf out " return %s;\n" (expr_to_c cse_result.final_expr);
  Buffer.add_string out "}\n";
  Buffer.contents out
|
|
|
|
(** [compile_to_cuda expr vars] returns CUDA source text containing two
    functions:

    - [__device__ double compute_device(...)], which evaluates [expr] after
      [common_subexpression_elimination], with one [double] parameter per
      entry of [vars];
    - [__global__ void compute_kernel(double* input, double* output, int n)],
      which, for each [idx < n], calls [compute_device] with arguments
      [input[idx * k + 0] .. input[idx * k + (k - 1)]] where [k] is the
      length of [vars], and stores the result in [output[idx]].

    @param expr the expression to compile
    @param vars names of the device function's parameters, in order
    @return the generated source text *)
let compile_to_cuda expr vars =
  let cse_result = common_subexpression_elimination expr in
  let arity = List.length vars in
  let formals =
    vars |> List.map (Printf.sprintf "double %s") |> String.concat ", "
  in
  (* The i-th argument of element [idx] is read from [input[idx * arity + i]]. *)
  let actuals =
    vars
    |> List.mapi (fun i _ -> Printf.sprintf "input[idx * %d + %d]" arity i)
    |> String.concat ", "
  in
  let out = Buffer.create 1024 in

  (* Device function: CSE temporaries, then the final expression. *)
  Printf.bprintf out "__device__ double compute_device(%s) {\n" formals;
  cse_result.subexpressions
  |> List.iter (fun (tmp, rhs) ->
       Printf.bprintf out " double %s = %s;\n" tmp (expr_to_c rhs));
  Printf.bprintf out " return %s;\n" (expr_to_c cse_result.final_expr);
  Buffer.add_string out "}\n\n";

  (* Kernel wrapper around the device function. *)
  Buffer.add_string out
    "__global__ void compute_kernel(double* input, double* output, int n) {\n";
  Buffer.add_string out
    " int idx = blockIdx.x * blockDim.x + threadIdx.x;\n";
  Buffer.add_string out " if (idx < n) {\n";
  Printf.bprintf out " output[idx] = compute_device(%s);\n" actuals;
  Buffer.add_string out " }\n";
  Buffer.add_string out "}\n";
  Buffer.contents out
|
|
|
|
(** [compile_to_vectorized expr vars] returns C source text for
    [void compute_vectorized(..., double* restrict output, int n)], a loop
    annotated with [#pragma omp simd] that evaluates [expr] for every index
    [i] in [0, n) and stores the result in [output[i]].

    Each name in [vars] becomes a [const double* restrict] parameter, and
    every occurrence of that variable in the expression is read as
    [name[i]]. The indexing is applied on the expression tree (to [Var]
    nodes whose name is in [vars]), not on the rendered text, so it:
    - covers all inputs, not only the first one;
    - covers the final expression as well as the CSE temporaries;
    - cannot corrupt function names that merely contain a variable name;
    - does not fail when [vars] is empty.

    NOTE(review): assumes the temporary names produced by
    [common_subexpression_elimination] never coincide with a name in
    [vars] — confirm against the [Cse] module.

    @param expr the expression to compile
    @param vars names of the input arrays, in parameter order
    @return the generated source, starting with [#include <math.h>] *)
let compile_to_vectorized expr vars =
  let cse_result = common_subexpression_elimination expr in
  (* Rewrite references to input arrays so they read element [i]. *)
  let rec index_inputs e =
    let go = index_inputs in
    match e with
    | Var v -> if List.mem v vars then Var (v ^ "[i]") else e
    | Const _ | SymConst _ -> e
    | Add (a, b) -> Add (go a, go b)
    | Sub (a, b) -> Sub (go a, go b)
    | Mul (a, b) -> Mul (go a, go b)
    | Div (a, b) -> Div (go a, go b)
    | Pow (a, b) -> Pow (go a, go b)
    | Atan2 (a, b) -> Atan2 (go a, go b)
    | Log (a, b) -> Log (go a, go b)
    | Neg a -> Neg (go a)
    | Sin a -> Sin (go a)
    | Cos a -> Cos (go a)
    | Tan a -> Tan (go a)
    | Sinh a -> Sinh (go a)
    | Cosh a -> Cosh (go a)
    | Tanh a -> Tanh (go a)
    | Asin a -> Asin (go a)
    | Acos a -> Acos (go a)
    | Atan a -> Atan (go a)
    | Exp a -> Exp (go a)
    | Ln a -> Ln (go a)
    | Sqrt a -> Sqrt (go a)
    | Abs a -> Abs (go a)
  in
  let render e = expr_to_c (index_inputs e) in
  (* Join inputs and the fixed trailing parameters as one list so that an
     empty [vars] does not leave a dangling leading comma. *)
  let input_params =
    List.map (Printf.sprintf "const double* restrict %s") vars
  in
  let params =
    String.concat ", "
      (input_params @ [ "double* restrict output"; "int n" ])
  in
  let body = Buffer.create 1024 in
  Buffer.add_string body "#include <math.h>\n\n";
  Printf.bprintf body "void compute_vectorized(%s) {\n" params;
  Buffer.add_string body " #pragma omp simd\n";
  Buffer.add_string body " for (int i = 0; i < n; i++) {\n";
  List.iter
    (fun (name, e) ->
      Printf.bprintf body " double %s = %s;\n" name (render e))
    cse_result.subexpressions;
  Printf.bprintf body " output[i] = %s;\n" (render cse_result.final_expr);
  Buffer.add_string body " }\n";
  Buffer.add_string body "}\n";
  Buffer.contents body
|