open Expr
open Cse

(** C type used for every generated scalar. The argument (a variable name)
    is ignored: all values are emitted as [double]. *)
let c_type_of_var _ = "double"

(** [c_float_literal f] renders [f] as a C expression of type [double].

    - NaN and infinities become the [math.h] macros [NAN] and [INFINITY],
      because the plain textual forms are not valid C literals.
    - Finite values use the shortest of 15, 16 or 17 significant digits that
      reads back to exactly [f], so no precision is lost.
    - Negative results are parenthesised so that an enclosing unary minus
      cannot fuse with the sign into the [--] operator. *)
let c_float_literal f =
  match classify_float f with
  | FP_nan -> "NAN"
  | FP_infinite -> if f > 0.0 then "INFINITY" else "(-INFINITY)"
  | FP_zero | FP_subnormal | FP_normal ->
    let round_trips s = float_of_string s = f in
    let digits =
      let s15 = Printf.sprintf "%.15g" f in
      if round_trips s15 then s15
      else
        let s16 = Printf.sprintf "%.16g" f in
        if round_trips s16 then s16 else Printf.sprintf "%.17g" f
    in
    (* Make sure the token is a floating literal, not an integer one. *)
    let literal =
      if String.contains digits '.' || String.contains digits 'e' then digits
      else digits ^ ".0"
    in
    if literal.[0] = '-' then "(" ^ literal ^ ")" else literal

(** [expr_to_c_with render_var e] renders [e] as a C expression string,
    using [render_var] to print each [Var] name. Binary arithmetic is fully
    parenthesised; [Pow], trigonometric, hyperbolic, exponential and root
    nodes map to the corresponding [math.h] calls; [Log (b, x)] is emitted
    as [log(x) / log(b)]. *)
let rec expr_to_c_with render_var expr =
  let go = expr_to_c_with render_var in
  let infix op e1 e2 = Printf.sprintf "(%s %s %s)" (go e1) op (go e2) in
  let call1 fn e = Printf.sprintf "%s(%s)" fn (go e) in
  let call2 fn e1 e2 = Printf.sprintf "%s(%s, %s)" fn (go e1) (go e2) in
  match expr with
  | Const f -> c_float_literal f
  | SymConst Pi -> "M_PI"
  | SymConst E -> "M_E"
  | Var v -> render_var v
  | Add (e1, e2) -> infix "+" e1 e2
  | Sub (e1, e2) -> infix "-" e1 e2
  | Mul (e1, e2) -> infix "*" e1 e2
  | Div (e1, e2) -> infix "/" e1 e2
  | Pow (e1, e2) -> call2 "pow" e1 e2
  | Neg e -> Printf.sprintf "(-%s)" (go e)
  | Sin e -> call1 "sin" e
  | Cos e -> call1 "cos" e
  | Tan e -> call1 "tan" e
  | Sinh e -> call1 "sinh" e
  | Cosh e -> call1 "cosh" e
  | Tanh e -> call1 "tanh" e
  | Asin e -> call1 "asin" e
  | Acos e -> call1 "acos" e
  | Atan e -> call1 "atan" e
  | Atan2 (e1, e2) -> call2 "atan2" e1 e2
  | Exp e -> call1 "exp" e
  | Ln e -> call1 "log" e
  | Log (base, arg) ->
    (* Change of base: log_b(x) = log(x) / log(b). *)
    Printf.sprintf "(log(%s) / log(%s))" (go arg) (go base)
  | Sqrt e -> call1 "sqrt" e
  | Abs e -> call1 "fabs" e

(** [expr_to_c e] renders [e] as a C expression in which every variable is
    printed by its bare name. *)
let expr_to_c expr = expr_to_c_with (fun v -> v) expr

(* Appends one local declaration per CSE binding, in list order, rendering
   each right-hand side with [to_c]. *)
let add_cse_bindings buf to_c bindings =
  List.iter
    (fun (name, e) ->
      Printf.bprintf buf " %s %s = %s;\n" (c_type_of_var name) name (to_c e))
    bindings

(* Comma-separated by-value scalar parameter list for [vars]. *)
let scalar_params vars =
  vars
  |> List.map (fun v -> Printf.sprintf "%s %s" (c_type_of_var v) v)
  |> String.concat ", "

(** [compile_to_c expr vars] returns the source text of a C function
    [double compute(...)] taking one [double] parameter per name in [vars]
    and returning the value of [expr]. The expression first goes through
    [common_subexpression_elimination]; each resulting binding becomes a
    local [double] ahead of the [return]. The output starts with an include
    of [math.h], which the generated calls and constants rely on. *)
let compile_to_c expr vars =
  let cse_result = common_subexpression_elimination expr in
  let body = Buffer.create 1024 in
  (* The header name was missing, which made the directive invalid C. *)
  Buffer.add_string body "#include <math.h>\n\n";
  Printf.bprintf body "double compute(%s) {\n" (scalar_params vars);
  add_cse_bindings body expr_to_c cse_result.subexpressions;
  Printf.bprintf body " return %s;\n" (expr_to_c cse_result.final_expr);
  Buffer.add_string body "}\n";
  Buffer.contents body

(** [compile_to_cuda expr vars] returns CUDA source text with two functions:
    a [__device__] function [compute_device] evaluating [expr] from scalar
    parameters named after [vars], and a [__global__] kernel
    [compute_kernel] that, for each index below [n], passes
    [input[idx * k + j]] as the j-th argument (k being the number of
    variables) and stores the result in [output[idx]]. *)
let compile_to_cuda expr vars =
  let cse_result = common_subexpression_elimination expr in
  (* Number of inputs consumed per element; computed once. *)
  let stride = List.length vars in
  let kernel_args =
    vars
    |> List.mapi (fun i _ -> Printf.sprintf "input[idx * %d + %d]" stride i)
    |> String.concat ", "
  in
  let body = Buffer.create 1024 in
  Buffer.add_string body "__device__ double compute_device(";
  Buffer.add_string body (scalar_params vars);
  Buffer.add_string body ") {\n";
  add_cse_bindings body expr_to_c cse_result.subexpressions;
  Printf.bprintf body " return %s;\n" (expr_to_c cse_result.final_expr);
  Buffer.add_string body "}\n\n";
  Buffer.add_string body
    "__global__ void compute_kernel(double* input, double* output, int n) {\n";
  Buffer.add_string body
    " int idx = blockIdx.x * blockDim.x + threadIdx.x;\n";
  Buffer.add_string body " if (idx < n) {\n";
  Printf.bprintf body " output[idx] = compute_device(%s);\n" kernel_args;
  Buffer.add_string body " }\n";
  Buffer.add_string body "}\n";
  Buffer.contents body

(** [compile_to_vectorized expr vars] returns the source text of a C
    function [compute_vectorized] that takes one [const double* restrict]
    array per name in [vars], an [output] array and a length [n], and
    evaluates [expr] element-wise inside an [omp simd] loop.

    Every occurrence of a name from [vars] is emitted with an [[i]]
    subscript, both in the CSE bindings and in the final expression. Names
    not in [vars] (such as CSE temporaries) are emitted bare. [vars] may be
    empty.

    NOTE(review): the generated code uses the identifiers [i], [n] and
    [output]; a variable with one of those names would clash. *)
let compile_to_vectorized expr vars =
  let cse_result = common_subexpression_elimination expr in
  (* Subscript inputs while printing the AST. The previous textual
     substitution only handled the first variable and also rewrote
     substrings of unrelated identifiers, such as the x in exp. *)
  let render_var v =
    if List.exists (String.equal v) vars then v ^ "[i]" else v
  in
  let to_c = expr_to_c_with render_var in
  let input_params =
    List.map
      (fun v -> Printf.sprintf "const %s* restrict %s" (c_type_of_var v) v)
      vars
  in
  (* Join all parameters together so an empty [vars] cannot leave a
     leading comma in the signature. *)
  let params =
    String.concat ", " (input_params @ [ "double* restrict output"; "int n" ])
  in
  let body = Buffer.create 1024 in
  (* The header name was missing, which made the directive invalid C. *)
  Buffer.add_string body "#include <math.h>\n\n";
  Printf.bprintf body "void compute_vectorized(%s) {\n" params;
  Buffer.add_string body " #pragma omp simd\n";
  Buffer.add_string body " for (int i = 0; i < n; i++) {\n";
  add_cse_bindings body to_c cse_result.subexpressions;
  Printf.bprintf body " output[i] = %s;\n" (to_c cse_result.final_expr);
  Buffer.add_string body " }\n";
  Buffer.add_string body "}\n";
  Buffer.contents body