diff --git a/bin/main.ml b/bin/main.ml
index ba95c04..0f2abb7 100644
--- a/bin/main.ml
+++ b/bin/main.ml
@@ -46,7 +46,47 @@ let parse_file path =
       Affinescript.Span.pp_short span msg;
     `Error (false, "Parse error")
 
-(** Check a file *)
+(** Run a file through the interpreter: parse [path], load the stdlib prelude into a fresh environment, then evaluate the program. Lexer, parse, runtime and [Failure] errors are printed to stderr and mapped to [`Error (false, _)]; any other exception propagates. *)
+let run_file path =
+  try
+    let prog = Affinescript.Parse_driver.parse_file path in
+    let env = Affinescript.Value.empty_env () in
+    (* Load stdlib prelude into [env] before the program is evaluated *)
+    Affinescript.Stdlib.load_prelude env;
+    Affinescript.Eval.eval_program env prog;
+    `Ok ()
+  with
+  | Affinescript.Lexer.Lexer_error (msg, pos) -> (* carries a position only, so [path] is prefixed by hand *)
+    Format.eprintf "@[%s:%d:%d: lexer error: %s@]@." path pos.Affinescript.Span.line pos.Affinescript.Span.col msg;
+    `Error (false, "Lexer error")
+  | Affinescript.Parse_driver.Parse_error (msg, span) ->
+    Format.eprintf "@[%a: parse error: %s@]@."
+      Affinescript.Span.pp_short span msg;
+    `Error (false, "Parse error")
+  | Affinescript.Eval.Runtime_error (msg, span_opt) -> (* the span is optional; the location prefix is omitted when absent *)
+    (match span_opt with
+     | Some span -> Format.eprintf "@[%a: runtime error: %s@]@." Affinescript.Span.pp_short span msg
+     | None -> Format.eprintf "@[runtime error: %s@]@." msg);
+    `Error (false, "Runtime error")
+  | Failure msg ->
+    Format.eprintf "@[error: %s@]@." msg;
+    `Error (false, "Error")
+
+(** Evaluate an expression from command line in a fresh environment with the prelude loaded: on [Ok] the value is printed to stdout via [Value.show]; on [Error] the message is printed to stderr and also returned in [`Error (false, msg)]. *)
+let eval_expr expr_str =
+  let env = Affinescript.Value.empty_env () in
+  Affinescript.Stdlib.load_prelude env;
+  match Affinescript.Repl.eval_string ~env expr_str with
+  | Ok v -> Format.printf "%s@." (Affinescript.Value.show v); `Ok ()
+  | Error msg -> Format.eprintf "Error: %s@." 
msg; `Error (false, msg)
+
+(** Start the REPL: with [Some file] this calls [Repl.run_with_file file], otherwise [Repl.run ()]; [`Ok ()] is returned once that call returns. *)
+let repl_run file_opt =
+  match file_opt with
+  | Some file -> Affinescript.Repl.run_with_file file; `Ok ()
+  | None -> Affinescript.Repl.run (); `Ok ()
+
+(** Check a file (type check - placeholder) *)
 let check_file path =
   let source = read_file path in
   let _ = source in
@@ -68,6 +108,12 @@ open Cmdliner
 let path_arg =
   Arg.(required & pos 0 (some file) None & info [] ~docv:"FILE" ~doc:"Input file")
 
+let optional_path_arg = (* optional positional FILE at index 0; evaluates to [None] when omitted *)
+  Arg.(value & pos 0 (some file) None & info [] ~docv:"FILE" ~doc:"Input file to load")
+
+let expr_arg = (* required positional EXPR at index 0, taken as a plain string *)
+  Arg.(required & pos 0 (some string) None & info [] ~docv:"EXPR" ~doc:"Expression to evaluate")
+
 let output_arg =
   Arg.(value & opt string "out.wasm" & info ["o"; "output"] ~docv:"FILE" ~doc:"Output file")
 
@@ -81,6 +127,21 @@ let parse_cmd =
   let info = Cmd.info "parse" ~doc in
   Cmd.v info Term.(ret (const parse_file $ path_arg))
 
+let run_cmd = (* [run FILE] subcommand, dispatching to [run_file] *)
+  let doc = "Run a file through the interpreter" in
+  let info = Cmd.info "run" ~doc in
+  Cmd.v info Term.(ret (const run_file $ path_arg))
+
+let eval_cmd = (* [eval EXPR] subcommand, dispatching to [eval_expr] *)
+  let doc = "Evaluate an expression" in
+  let info = Cmd.info "eval" ~doc in
+  Cmd.v info Term.(ret (const eval_expr $ expr_arg))
+
+let repl_cmd = (* [repl [FILE]] subcommand, dispatching to [repl_run] *)
+  let doc = "Start the interactive REPL" in
+  let info = Cmd.info "repl" ~doc in
+  Cmd.v info Term.(ret (const repl_run $ optional_path_arg))
+
 let check_cmd =
   let doc = "Type check a file" in
   let info = Cmd.info "check" ~doc in
@@ -93,8 +154,25 @@ let compile_cmd =
 
 let default_cmd =
   let doc = "The AffineScript compiler" in
-  let info = Cmd.info "affinescript" ~version ~doc in
+  let sdocs = Manpage.s_common_options in
+  let man = [ (* man page body: description, pointer to per-command help, usage examples *)
+    `S Manpage.s_description;
+    `P "AffineScript is a systems programming language with affine types, \
+        dependent types, row polymorphism, and extensible effects.";
+    `S Manpage.s_commands;
+    `P "Use $(b,affinescript COMMAND --help) for help on a specific command.";
+    `S Manpage.s_examples;
+    `P "Start the REPL:";
+    `Pre " 
$(b,affinescript repl)";
+    `P "Run a file:";
+    `Pre " $(b,affinescript run hello.afs)";
+    `P "Evaluate an expression:";
+    `Pre " $(b,affinescript eval \"1 + 2 * 3\")";
+  ] in
+  let info = Cmd.info "affinescript" ~version ~doc ~sdocs ~man in
   let default = Term.(ret (const (`Help (`Pager, None)))) in
-  Cmd.group info ~default [lex_cmd; parse_cmd; check_cmd; compile_cmd]
+  Cmd.group info ~default [
+    repl_cmd; run_cmd; eval_cmd; lex_cmd; parse_cmd; check_cmd; compile_cmd
+  ]
 
 let () = exit (Cmd.eval default_cmd)
diff --git a/examples/enums.afs b/examples/enums.afs
new file mode 100644
index 0000000..0c4c067
--- /dev/null
+++ b/examples/enums.afs
@@ -0,0 +1,68 @@
+// Enums and algebraic data types in AffineScript
+
+// Simple enum
+enum Color {
+    Red,
+    Green,
+    Blue,
+}
+
+// Enum with data (sum type)
+enum Option[T] {
+    None,
+    Some(T),
+}
+
+enum Result[T, E] {
+    Ok(T),
+    Err(E),
+}
+
+// Using enums with pattern matching: one arm per Color variant
+fn describe_color(c: Color) -> String {
+    match c {
+        Color::Red => "The color of fire",
+        Color::Green => "The color of nature",
+        Color::Blue => "The color of the sky",
+    }
+}
+
+// Option handling: None when the divisor is 0, otherwise Some(a / b)
+fn safe_divide(a: Int, b: Int) -> Option[Int] {
+    if b == 0 {
+        None
+    } else {
+        Some(a / b)
+    }
+}
+
+fn show_result(opt: Option[Int]) -> String {
+    match opt {
+        Some(x) => "Result: " + str(x),
+        None => "Division by zero!",
+    }
+}
+
+println(describe_color(Color::Blue));
+println(show_result(safe_divide(10, 2)));
+println(show_result(safe_divide(10, 0)));
+
+// Result for error handling (safe_sqrt below is defined but not called in this example)
+enum MathError {
+    DivisionByZero,
+    NegativeRoot,
+}
+
+fn safe_sqrt(x: Int) -> Result[Int, MathError] {
+    if x < 0 {
+        Err(MathError::NegativeRoot)
+    } else {
+        // Simplified integer square root: start at x / 2 (at least 1) and repeat guess = (guess + x / guess) / 2 while guess * guess > x
+        let mut guess = x / 2;
+        if guess == 0 { guess = 1; }
+        while guess * guess > x {
+            guess = (guess + x / guess) / 2;
+        }
+        Ok(guess)
+    }
+}
diff --git a/examples/factorial.afs b/examples/factorial.afs
new file mode 100644
index 0000000..72928cc
--- /dev/null
+++ 
b/examples/factorial.afs
@@ -0,0 +1,19 @@
+// Factorial calculation in AffineScript
+
+fn factorial(n: Int) -> Int {
+    if n <= 1 { 1 }
+    else { n * factorial(n - 1) }
+}
+
+// Using fold for a functional approach: multiplies over range(1, n + 1); defined for comparison, not called below
+fn factorial_fold(n: Int) -> Int {
+    if n <= 1 { 1 }
+    else {
+        fold(\acc: Int, x: Int -> acc * x, 1, range(1, n + 1))
+    }
+}
+
+println("Factorials:");
+for i in range(1, 11) {
+    println(str(i) + "! = " + str(factorial(i)));
+}
diff --git a/examples/fibonacci.afs b/examples/fibonacci.afs
new file mode 100644
index 0000000..3c323ab
--- /dev/null
+++ b/examples/fibonacci.afs
@@ -0,0 +1,36 @@
+// Fibonacci sequence in AffineScript
+
+// Recursive implementation: two recursive calls per step
+fn fib_recursive(n: Int) -> Int {
+    if n <= 1 { n }
+    else { fib_recursive(n - 1) + fib_recursive(n - 2) }
+}
+
+// Iterative implementation (more efficient): a single loop from 2 to n carrying the last two values in a and b
+fn fib_iterative(n: Int) -> Int {
+    if n <= 1 { n }
+    else {
+        let mut a = 0;
+        let mut b = 1;
+        let mut i = 2;
+        while i <= n {
+            let temp = a + b;
+            a = b;
+            b = temp;
+            i += 1;
+        }
+        b
+    }
+}
+
+// Print the Fibonacci numbers for each i in range(10), first with the recursive then with the iterative implementation
+println("Fibonacci sequence (recursive):");
+for i in range(10) {
+    println(str(fib_recursive(i)));
+}
+
+println("");
+println("Fibonacci sequence (iterative):");
+for i in range(10) {
+    println(str(fib_iterative(i)));
+}
diff --git a/examples/hello.afs b/examples/hello.afs
new file mode 100644
index 0000000..7aa2ce7
--- /dev/null
+++ b/examples/hello.afs
@@ -0,0 +1,7 @@
+// Hello World in AffineScript
+
+fn main() {
+    println("Hello, AffineScript!");
+}
+
+main()
diff --git a/examples/higher_order.afs b/examples/higher_order.afs
new file mode 100644
index 0000000..4c7fb0e
--- /dev/null
+++ b/examples/higher_order.afs
@@ -0,0 +1,33 @@
+// Higher-order functions in AffineScript
+
+// Map: transform each element
+let numbers = [1, 2, 3, 4, 5];
+let doubled = map(\x: Int -> x * 2, numbers);
+println("Doubled: " + str(doubled));
+
+// Filter: keep elements matching predicate
+let evens = filter(\x: Int -> x % 2 == 
0, numbers);
+println("Evens: " + str(evens));
+
+// Fold: reduce to single value
+let sum = fold(\acc: Int, x: Int -> acc + x, 0, numbers);
+println("Sum: " + str(sum));
+
+// Composition example: compose(f, g) returns a lambda that applies g first, then f
+fn compose(f: fn(Int) -> Int, g: fn(Int) -> Int) -> fn(Int) -> Int {
+    \x: Int -> f(g(x))
+}
+
+let add_one = \x: Int -> x + 1;
+let times_two = \x: Int -> x * 2;
+let add_then_double = compose(times_two, add_one);
+
+println("(5 + 1) * 2 = " + str(add_then_double(5)));
+
+// Pipeline pattern
+let result = numbers
+    |> map(\x: Int -> x * x) // Square each
+    |> filter(\x: Int -> x > 5) // Keep > 5
+    |> fold(\a: Int, x: Int -> a + x, 0); // Sum
+
+println("Sum of squares > 5: " + str(result));
diff --git a/examples/records.afs b/examples/records.afs
new file mode 100644
index 0000000..f0db9ff
--- /dev/null
+++ b/examples/records.afs
@@ -0,0 +1,35 @@
+// Records and pattern matching in AffineScript
+
+// Define a person record type (the values below are written with anonymous record syntax and do not name Person)
+struct Person {
+    name: String,
+    age: Int,
+    email: String,
+}
+
+// Create some records using anonymous record syntax
+let alice = { name: "Alice", age: 30, email: "alice@example.com" };
+let bob = { name: "Bob", age: 25, email: "bob@example.com" };
+
+println("Person: " + alice.name + ", age " + str(alice.age));
+
+// Record update syntax: copy alice, replacing only the age field
+let older_alice = { ..alice, age: alice.age + 1 };
+println("After birthday: " + older_alice.name + " is " + str(older_alice.age));
+
+// Open record parameter type: the parameter is any record with a 'name' field; the body uses plain field access, not a pattern match
+fn greet(person: { name: String, ..r }) -> String {
+    "Hello, " + person.name + "!" 
+}
+
+println(greet(alice));
+println(greet(bob));
+
+// Row polymorphism - function works with any record that has 'name' field
+fn get_name(r: { name: String, ..rest }) -> String {
+    r.name
+}
+
+let company = { name: "Acme Corp", founded: 1990 };
+println("Company: " + get_name(company));
+println("Person: " + get_name(alice));
diff --git a/lib/borrow.ml b/lib/borrow.ml
new file mode 100644
index 0000000..c463b67
--- /dev/null
+++ b/lib/borrow.ml
@@ -0,0 +1,439 @@
+(** Borrow checker for AffineScript - enforces affine type rules *)
+
+open Ast
+
+(** Usage state for a binding; stored in [bi_state] and advanced by [record_use], [record_move] and [record_borrow] *)
+type usage_state =
+  | Unused (** Never used *)
+  | Used of Span.t (** Used exactly once at this location *)
+  | Moved of Span.t (** Ownership transferred (moved) *)
+  | Borrowed of Span.t (** Borrowed (referenced) *)
+  | MutBorrowed of Span.t (** Mutably borrowed *)
+  | MultipleUses (** Used more than once (error for linear) *)
+[@@deriving show]
+
+(** Binding information: one record per name per scope, created by [add_binding] *)
+type binding_info = {
+  bi_name: string; (** name the binding was registered under *)
+  bi_span: Span.t option; (** location supplied when the binding was added, if any *)
+  bi_quantity: Types.quantity; (** only [QOne] bindings are subject to the linearity checks *)
+  bi_mutable: bool; (** consulted when the binding is the target of an assignment *)
+  bi_ownership: ownership option; (** stored as given; not read by the checks visible in this file *)
+  mutable bi_state: usage_state; (** starts as [Unused] and is updated in place *)
+}
+
+(** Borrow errors; payload meanings below follow the constructor call sites in this file *)
+type error =
+  | UseAfterMove of string * Span.t * Span.t (** name, use location, move location *)
+  | LinearUsedMultipleTimes of string * Span.t * Span.t (** name, later use location, earlier use location *)
+  | LinearNotUsed of string * Span.t option (** name, binding location if known *)
+  | CannotMoveFromBorrow of string * Span.t (** name, location of the existing borrow *)
+  | CannotMutateBorrowed of string * Span.t (** name, location of the assignment target *)
+  | DoubleMutableBorrow of string * Span.t * Span.t (** name, new borrow location, existing borrow location *)
+  | BorrowedWhileMutablyBorrowed of string * Span.t * Span.t (** name, new borrow location, existing borrow location *)
+  | EscapingBorrow of string * Span.t (** not constructed anywhere in the code visible here *)
+  | UninitializedUse of string * Span.t (** not constructed anywhere in the code visible here *)
+
+exception Borrow_error of error (* declared but not raised in the code visible here; errors are accumulated through add_error *)
+
+let error_to_string = function (* message text only: every span carried by the error is ignored *)
+  | UseAfterMove (name, _, _) ->
+    Printf.sprintf "Use of moved value: %s" name
+  | LinearUsedMultipleTimes (name, _, _) ->
+    Printf.sprintf "Linear value '%s' used multiple times" name
+  | LinearNotUsed (name, _) ->
+    Printf.sprintf "Linear value '%s' not used" name 
+  | CannotMoveFromBorrow (name, _) ->
+    Printf.sprintf "Cannot move out of borrowed value: %s" name
+  | CannotMutateBorrowed (name, _) ->
+    Printf.sprintf "Cannot mutate borrowed value: %s" name
+  | DoubleMutableBorrow (name, _, _) ->
+    Printf.sprintf "Cannot borrow '%s' mutably more than once" name
+  | BorrowedWhileMutablyBorrowed (name, _, _) ->
+    Printf.sprintf "Cannot borrow '%s' while mutably borrowed" name
+  | EscapingBorrow (name, _) ->
+    Printf.sprintf "Borrowed value '%s' escapes its scope" name
+  | UninitializedUse (name, _) ->
+    Printf.sprintf "Use of uninitialized value: %s" name
+
+(** Borrow checker context: one value per lexical scope, linked to the enclosing scope through [ctx_parent] *)
+type context = {
+  ctx_bindings: (string, binding_info) Hashtbl.t; (** bindings introduced in this scope only *)
+  ctx_parent: context option; (** enclosing scope; [None] for the root *)
+  ctx_borrows: (string, Span.t * bool) Hashtbl.t; (** name -> (span, is_mutable) *)
+  ctx_errors: error list ref; (** shared by every scope derived from the same root; newest error first *)
+  ctx_in_loop: bool; (** set by loop_context and inherited by children; no check visible here consults it *)
+}
+
+let create_context () = { (* root scope: no parent and a fresh, empty error list *)
+  ctx_bindings = Hashtbl.create 32;
+  ctx_parent = None;
+  ctx_borrows = Hashtbl.create 16;
+  ctx_errors = ref [];
+  ctx_in_loop = false;
+}
+
+let child_context parent = { (* nested scope: fresh binding and borrow tables, same error list and loop flag as the parent *)
+  ctx_bindings = Hashtbl.create 16;
+  ctx_parent = Some parent;
+  ctx_borrows = Hashtbl.create 8;
+  ctx_errors = parent.ctx_errors;
+  ctx_in_loop = parent.ctx_in_loop;
+}
+
+let loop_context parent = { (* same as child_context except that ctx_in_loop is forced to true *)
+  ctx_bindings = Hashtbl.create 16;
+  ctx_parent = Some parent;
+  ctx_borrows = Hashtbl.create 8;
+  ctx_errors = parent.ctx_errors;
+  ctx_in_loop = true;
+}
+
+let add_error ctx err = (* prepend to the shared list; check_program reverses it before returning *)
+  ctx.ctx_errors := err :: !(ctx.ctx_errors)
+
+(** Add a binding to context in state [Unused]; an entry with the same name already in this scope is replaced *)
+let add_binding ctx name ~span ~quantity ~mutable_ ~ownership =
+  let info = {
+    bi_name = name;
+    bi_span = span;
+    bi_quantity = quantity;
+    bi_mutable = mutable_;
+    bi_ownership = ownership;
+    bi_state = Unused;
+  } in
+  Hashtbl.replace ctx.ctx_bindings name info
+
+(** Look up a binding: this scope first, then each enclosing scope in turn; [None] when no scope has it *)
+let rec lookup_binding ctx name =
+  match Hashtbl.find_opt ctx.ctx_bindings name with
+  | Some info -> Some info
+  | None -> Option.bind ctx.ctx_parent (fun p -> 
lookup_binding p name)
+
+(** Record a use of a binding: a use after a move and a second use of a [QOne] binding are reported; every other case sets the state to [Used span] *)
+let record_use ctx name span =
+  match lookup_binding ctx name with
+  | None -> () (* Unknown binding - skip *)
+  | Some info ->
+    match info.bi_state, info.bi_quantity with
+    | Moved move_span, _ -> (* state stays Moved, so each later use is reported again *)
+      add_error ctx (UseAfterMove (name, span, move_span))
+    | Used use_span, Types.QOne ->
+      add_error ctx (LinearUsedMultipleTimes (name, span, use_span));
+      info.bi_state <- MultipleUses
+    | _, _ -> (* NOTE(review): this fallthrough also overwrites Borrowed, MutBorrowed and MultipleUses with Used - confirm that is intended *)
+      info.bi_state <- Used span
+
+(** Record a move of a binding: moving an already moved or currently borrowed binding is reported and leaves the state unchanged; otherwise the state becomes [Moved span] *)
+let record_move ctx name span =
+  match lookup_binding ctx name with
+  | None -> ()
+  | Some info ->
+    match info.bi_state with
+    | Moved move_span ->
+      add_error ctx (UseAfterMove (name, span, move_span))
+    | Borrowed borrow_span ->
+      add_error ctx (CannotMoveFromBorrow (name, borrow_span))
+    | MutBorrowed borrow_span ->
+      add_error ctx (CannotMoveFromBorrow (name, borrow_span))
+    | _ ->
+      info.bi_state <- Moved span
+
+(** Record a borrow: conflicts are looked up in the [ctx_borrows] table of the current scope only, then the borrow is stored and the binding state updated whether or not a conflict was reported *)
+let record_borrow ctx name span ~mutable_ =
+  match lookup_binding ctx name with
+  | None -> ()
+  | Some info ->
+    (* Check existing borrows; a shared borrow on top of a shared borrow is the only combination that passes silently *)
+    (match Hashtbl.find_opt ctx.ctx_borrows name with
+     | Some (existing_span, true) when mutable_ ->
+       add_error ctx (DoubleMutableBorrow (name, span, existing_span))
+     | Some (existing_span, true) ->
+       add_error ctx (BorrowedWhileMutablyBorrowed (name, span, existing_span))
+     | Some (existing_span, false) when mutable_ ->
+       add_error ctx (BorrowedWhileMutablyBorrowed (name, span, existing_span))
+     | _ -> ());
+    Hashtbl.replace ctx.ctx_borrows name (span, mutable_);
+    info.bi_state <- if mutable_ then MutBorrowed span else Borrowed span
+
+(** Check that all linear bindings are used at scope exit: reports [LinearNotUsed] for every [QOne] binding of this scope (parents are not visited) whose state is still [Unused] *)
+let check_linear_used ctx =
+  Hashtbl.iter (fun name info ->
+    if info.bi_quantity = Types.QOne then
+      match info.bi_state with
+      | Unused -> add_error ctx (LinearNotUsed (name, info.bi_span))
+      | _ -> ()
+  ) ctx.ctx_bindings
+
+(** Get span from expression *)
+let 
expr_span = function
+  | ExprSpan (_, span) -> Some span
+  | ExprVar id -> Some id.span
+  | ExprLit (LitInt (_, span)) -> Some span
+  | ExprLit (LitFloat (_, span)) -> Some span
+  | ExprLit (LitBool (_, span)) -> Some span
+  | ExprLit (LitChar (_, span)) -> Some span
+  | ExprLit (LitString (_, span)) -> Some span
+  | ExprLit (LitUnit span) -> Some span
+  | _ -> None (* no span is extracted for any other expression form *)
+
+(** Convert AST quantity to internal quantity; a missing annotation is treated as [QOmega] *)
+let convert_quantity = function
+  | Some Ast.QZero -> Types.QZero
+  | Some Ast.QOne -> Types.QOne
+  | Some Ast.QOmega -> Types.QOmega
+  | None -> Types.QOmega (* Default to unrestricted *)
+
+(** Check an expression, recording uses, moves and borrows in [ctx]; branches, match arms and lambdas get a child scope whose linear bindings are checked when the scope ends *)
+let rec check_expr ctx expr =
+  match expr with
+  | ExprSpan (e, _) -> check_expr ctx e
+
+  | ExprLit _ -> ()
+
+  | ExprVar id ->
+    record_use ctx id.name id.span
+
+  | ExprLet { el_mut; el_pat; el_value; el_body; _ } -> (* the value is checked before the pattern is bound; bindings go into the current scope *)
+    check_expr ctx el_value;
+    add_pattern_bindings ctx el_pat ~mutable_:el_mut;
+    Option.iter (check_expr ctx) el_body
+
+  | ExprIf { ei_cond; ei_then; ei_else } -> (* NOTE(review): both branches update the same outer binding_info records in place, so a move made in the then-branch is still recorded while the else-branch is checked *)
+    check_expr ctx ei_cond;
+    let then_ctx = child_context ctx in
+    check_expr then_ctx ei_then;
+    check_linear_used then_ctx;
+    Option.iter (fun e ->
+      let else_ctx = child_context ctx in
+      check_expr else_ctx e;
+      check_linear_used else_ctx
+    ) ei_else
+
+  | ExprMatch { em_scrutinee; em_arms } -> (* one child scope per arm; pattern variables are bound immutably before the guard and body are checked *)
+    check_expr ctx em_scrutinee;
+    List.iter (fun arm ->
+      let arm_ctx = child_context ctx in
+      add_pattern_bindings arm_ctx arm.ma_pat ~mutable_:false;
+      Option.iter (check_expr arm_ctx) arm.ma_guard;
+      check_expr arm_ctx arm.ma_body;
+      check_linear_used arm_ctx
+    ) em_arms
+
+  | ExprLambda { elam_params; elam_body; _ } -> (* parameters keep their declared quantity and ownership and are always immutable *)
+    let fn_ctx = child_context ctx in
+    List.iter (fun p ->
+      add_binding fn_ctx p.p_name.name
+        ~span:(Some p.p_name.span)
+        ~quantity:(convert_quantity p.p_quantity)
+        ~mutable_:false
+        ~ownership:p.p_ownership
+    ) elam_params;
+    check_expr fn_ctx elam_body;
+    check_linear_used fn_ctx
+
+  | ExprApp (func, args) ->
+    check_expr ctx func;
+    (* 
Arguments may be moved to the function *)
+    List.iter (fun arg ->
+      check_expr ctx arg;
+      (* If arg is a variable, consider it moved (conservative); record_move does not look at the quantity, so an unrestricted variable passed to two calls is reported as UseAfterMove on the second *)
+      match arg with
+      | ExprVar id -> record_move ctx id.name id.span
+      | _ -> ()
+    ) args
+
+  | ExprField (e, _) ->
+    check_expr ctx e
+
+  | ExprTupleIndex (e, _) ->
+    check_expr ctx e
+
+  | ExprIndex (e1, e2) ->
+    check_expr ctx e1;
+    check_expr ctx e2
+
+  | ExprTuple exprs ->
+    List.iter (check_expr ctx) exprs
+
+  | ExprArray exprs ->
+    List.iter (check_expr ctx) exprs
+
+  | ExprRecord { er_fields; er_spread } -> (* a field written without a value counts as a use of the variable with the same name *)
+    List.iter (fun (id, expr_opt) ->
+      match expr_opt with
+      | Some e -> check_expr ctx e
+      | None -> record_use ctx id.name id.span
+    ) er_fields;
+    Option.iter (check_expr ctx) er_spread
+
+  | ExprRowRestrict (e, _) ->
+    check_expr ctx e
+
+  | ExprBinary (e1, _, e2) ->
+    check_expr ctx e1;
+    check_expr ctx e2
+
+  | ExprUnary (op, e) ->
+    (match op with
+     | OpRef ->
+       (* Creating a reference borrows the value; only shared borrows are recorded here, and expr_span has no ExprUnary case, so the span is id.span *)
+       (match e with
+        | ExprVar id ->
+          let span = match expr_span expr with Some s -> s | None -> id.span in
+          record_borrow ctx id.name span ~mutable_:false
+        | _ -> check_expr ctx e)
+     | OpDeref -> check_expr ctx e
+     | _ -> check_expr ctx e)
+
+  | ExprBlock blk ->
+    check_block ctx blk
+
+  | ExprReturn e_opt ->
+    Option.iter (check_expr ctx) e_opt
+
+  | ExprTry { et_body; et_catch; et_finally } -> (* catch arms get one child scope each; unlike match arms, their guards are not checked here *)
+    check_block ctx et_body;
+    Option.iter (fun arms ->
+      List.iter (fun arm ->
+        let arm_ctx = child_context ctx in
+        add_pattern_bindings arm_ctx arm.ma_pat ~mutable_:false;
+        check_expr arm_ctx arm.ma_body;
+        check_linear_used arm_ctx
+      ) arms
+    ) et_catch;
+    Option.iter (check_block ctx) et_finally
+
+  | ExprHandle { eh_body; eh_handlers } -> (* each handler clause is checked in its own child scope with its patterns bound immutably *)
+    check_expr ctx eh_body;
+    List.iter (fun h ->
+      match h with
+      | HandlerReturn (pat, e) ->
+        let h_ctx = child_context ctx in
+        add_pattern_bindings h_ctx pat ~mutable_:false;
+        check_expr h_ctx e;
+        check_linear_used h_ctx
+      | HandlerOp (_, pats, e) ->
+        let h_ctx = 
child_context ctx in
+        List.iter (fun p -> add_pattern_bindings h_ctx p ~mutable_:false) pats;
+        check_expr h_ctx e;
+        check_linear_used h_ctx
+    ) eh_handlers
+
+  | ExprResume e_opt ->
+    Option.iter (check_expr ctx) e_opt
+
+  | ExprUnsafe ops -> (* operands of unsafe operations are checked like ordinary expressions *)
+    List.iter (fun op ->
+      match op with
+      | UnsafeRead e | UnsafeForget e -> check_expr ctx e
+      | UnsafeWrite (e1, e2) | UnsafeOffset (e1, e2) ->
+        check_expr ctx e1;
+        check_expr ctx e2
+      | UnsafeTransmute (_, _, e) -> check_expr ctx e
+      | UnsafeAssume _ -> ()
+    ) ops
+
+  | ExprVariant _ -> ()
+
+and check_block ctx { blk_stmts; blk_expr } = (* every block gets its own child scope, checked for unused linear bindings on exit *)
+  let block_ctx = child_context ctx in
+  List.iter (check_stmt block_ctx) blk_stmts;
+  Option.iter (check_expr block_ctx) blk_expr;
+  check_linear_used block_ctx
+
+and check_stmt ctx = function (* statements bind into and check against the scope they are given *)
+  | StmtLet { sl_mut; sl_pat; sl_value; _ } ->
+    check_expr ctx sl_value;
+    add_pattern_bindings ctx sl_pat ~mutable_:sl_mut
+
+  | StmtExpr e ->
+    check_expr ctx e
+
+  | StmtAssign (target, _, value) ->
+    check_expr ctx value;
+    (* Check target is mutable; NOTE(review): an immutable variable target is reported as CannotMutateBorrowed, the only mutation error the error type offers. Unknown names report nothing and other targets are checked as expressions *)
+    (match target with
+     | ExprVar id ->
+       (match lookup_binding ctx id.name with
+        | Some info when not info.bi_mutable ->
+          add_error ctx (CannotMutateBorrowed (id.name, id.span))
+        | _ -> ())
+     | _ -> check_expr ctx target)
+
+  | StmtWhile (cond, body) ->
+    (* In a loop, variables might be used multiple times; loop_context only sets ctx_in_loop, which no check visible here consults *)
+    let loop_ctx = loop_context ctx in
+    check_expr loop_ctx cond;
+    check_block loop_ctx body
+
+  | StmtFor (pat, iter, body) -> (* the iterated expression is checked in the outer scope, the loop variable is bound in the loop scope *)
+    check_expr ctx iter;
+    let loop_ctx = loop_context ctx in
+    add_pattern_bindings loop_ctx pat ~mutable_:false;
+    check_block loop_ctx body
+
+and add_pattern_bindings ctx pat ~mutable_ = (* every variable bound by a pattern is registered as QOmega with no ownership annotation *)
+  match pat with
+  | PatWildcard _ -> ()
+  | PatVar id ->
+    add_binding ctx id.name
+      ~span:(Some id.span)
+      ~quantity:Types.QOmega
+      ~mutable_
+      ~ownership:None
+  | PatLit _ -> ()
+  | PatCon (_, pats) ->
+    List.iter (fun p -> add_pattern_bindings ctx p ~mutable_) pats
+  | PatTuple pats ->
+    List.iter (fun p -> 
add_pattern_bindings ctx p ~mutable_) pats
+  | PatRecord (fields, _) -> (* a field without a sub-pattern binds a variable named after the field *)
+    List.iter (fun (id, pat_opt) ->
+      match pat_opt with
+      | Some p -> add_pattern_bindings ctx p ~mutable_
+      | None ->
+        add_binding ctx id.name
+          ~span:(Some id.span)
+          ~quantity:Types.QOmega
+          ~mutable_
+          ~ownership:None
+    ) fields
+  | PatOr (p1, p2) -> (* both alternatives are registered; a name bound by both is simply replaced by the second *)
+    add_pattern_bindings ctx p1 ~mutable_;
+    add_pattern_bindings ctx p2 ~mutable_
+  | PatAs (id, p) ->
+    add_binding ctx id.name
+      ~span:(Some id.span)
+      ~quantity:Types.QOmega
+      ~mutable_
+      ~ownership:None;
+    add_pattern_bindings ctx p ~mutable_
+
+(** Check a function declaration in a child scope of [ctx]: parameters are bound immutably with their declared quantity and ownership, the body is checked, then unused linear parameters are reported *)
+let check_fn_decl ctx (decl: fn_decl) =
+  let fn_ctx = child_context ctx in
+  (* Add parameters *)
+  List.iter (fun p ->
+    add_binding fn_ctx p.p_name.name
+      ~span:(Some p.p_name.span)
+      ~quantity:(convert_quantity p.p_quantity)
+      ~mutable_:false
+      ~ownership:p.p_ownership
+  ) decl.fd_params;
+  (* Check body; a block body opens a further child scope through check_block *)
+  (match decl.fd_body with
+   | FnBlock blk -> check_block fn_ctx blk
+   | FnExpr e -> check_expr fn_ctx e);
+  check_linear_used fn_ctx
+
+(** Check a program: returns every borrow error found, in the order reported (empty list when there are none) *)
+let check_program (prog: program) =
+  let ctx = create_context () in
+  (* Check function declarations and top-level constant initialisers; all other declarations are skipped *)
+  List.iter (fun decl ->
+    match decl with
+    | TopFn fd -> check_fn_decl ctx fd
+    | TopConst { tc_value; _ } -> check_expr ctx tc_value
+    | _ -> ()
+  ) prog.prog_decls;
+  List.rev !(ctx.ctx_errors)
diff --git a/lib/codegen.ml b/lib/codegen.ml
new file mode 100644
index 0000000..4557b95
--- /dev/null
+++ b/lib/codegen.ml
@@ -0,0 +1,658 @@
+(** WebAssembly code generation for AffineScript *)
+
+open Ast
+
+(** WASM value types *)
+type wasm_type =
+  | I32
+  | I64
+  | F32
+  | F64
+  | Funcref
+  | Externref
+[@@deriving show]
+
+(** WASM instructions *)
+type wasm_instr =
+  (* Constants *)
+  | I32Const of int
+  | I64Const of int64
+  | F32Const of float
+  | F64Const of float
+  (* Local/Global variables *)
+  | LocalGet of int
+  | LocalSet of int
+  | LocalTee of int
+  | GlobalGet of int
+  | GlobalSet 
of int + (* Memory *) + | I32Load of int * int (** align, offset *) + | I32Store of int * int + | I64Load of int * int + | I64Store of int * int + | F32Load of int * int + | F32Store of int * int + | F64Load of int * int + | F64Store of int * int + | MemorySize + | MemoryGrow + (* Arithmetic (i32) *) + | I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU + | I32And | I32Or | I32Xor | I32Shl | I32ShrS | I32ShrU + | I32Eqz | I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU + | I32LeS | I32LeU | I32GeS | I32GeU + (* Arithmetic (i64) *) + | I64Add | I64Sub | I64Mul | I64DivS | I64DivU | I64RemS | I64RemU + | I64And | I64Or | I64Xor | I64Shl | I64ShrS | I64ShrU + | I64Eqz | I64Eq | I64Ne | I64LtS | I64LtU | I64GtS | I64GtU + | I64LeS | I64LeU | I64GeS | I64GeU + (* Arithmetic (f32) *) + | F32Add | F32Sub | F32Mul | F32Div + | F32Eq | F32Ne | F32Lt | F32Gt | F32Le | F32Ge + | F32Neg | F32Abs | F32Sqrt | F32Ceil | F32Floor | F32Trunc + (* Arithmetic (f64) *) + | F64Add | F64Sub | F64Mul | F64Div + | F64Eq | F64Ne | F64Lt | F64Gt | F64Le | F64Ge + | F64Neg | F64Abs | F64Sqrt | F64Ceil | F64Floor | F64Trunc + (* Conversions *) + | I32WrapI64 + | I64ExtendI32S | I64ExtendI32U + | F32ConvertI32S | F32ConvertI32U | F32ConvertI64S | F32ConvertI64U + | F64ConvertI32S | F64ConvertI32U | F64ConvertI64S | F64ConvertI64U + | I32TruncF32S | I32TruncF32U | I32TruncF64S | I32TruncF64U + | I64TruncF32S | I64TruncF32U | I64TruncF64S | I64TruncF64U + | F32DemoteF64 | F64PromoteF32 + | I32ReinterpretF32 | I64ReinterpretF64 + | F32ReinterpretI32 | F64ReinterpretI64 + (* Control flow *) + | Unreachable + | Nop + | Block of wasm_type option * wasm_instr list + | Loop of wasm_type option * wasm_instr list + | If of wasm_type option * wasm_instr list * wasm_instr list option + | Br of int (** Branch to label *) + | BrIf of int (** Conditional branch *) + | BrTable of int list * int (** Branch table *) + | Return + | Call of int (** Call function by index *) + | CallIndirect of 
int (** Call function by table index *)
+  (* Parametric *)
+  | Drop
+  | Select
+[@@deriving show]
+
+(** WASM function type: parameter and result value types *)
+type wasm_func_type = {
+  ft_params: wasm_type list;
+  ft_results: wasm_type list;
+}
+[@@deriving show]
+
+(** WASM function *)
+type wasm_func = {
+  fn_name: string;
+  fn_type: wasm_func_type;
+  fn_locals: wasm_type list;
+  fn_body: wasm_instr list;
+  fn_export: bool;
+}
+[@@deriving show]
+
+(** WASM global *)
+type wasm_global = {
+  gl_name: string;
+  gl_type: wasm_type;
+  gl_mutable: bool;
+  gl_init: wasm_instr list;
+  gl_export: bool;
+}
+[@@deriving show]
+
+(** WASM import; the two leading strings are presumably the module name and the field name - TODO confirm against the emitter *)
+type wasm_import =
+  | ImportFunc of string * string * wasm_func_type
+  | ImportGlobal of string * string * wasm_type * bool
+  | ImportMemory of string * string * int * int option (** presumably min and optional max, as in mod_memory - TODO confirm *)
+  | ImportTable of string * string * int * int option
+[@@deriving show]
+
+(** WASM export; the int payloads are presumably indices into the function and global index spaces - TODO confirm *)
+type wasm_export =
+  | ExportFunc of string * int
+  | ExportGlobal of string * int
+  | ExportMemory of string
+  | ExportTable of string
+[@@deriving show]
+
+(** WASM module; unlike the types above it carries no deriving attribute *)
+type wasm_module = {
+  mod_types: wasm_func_type list;
+  mod_imports: wasm_import list;
+  mod_funcs: wasm_func list;
+  mod_globals: wasm_global list;
+  mod_exports: wasm_export list;
+  mod_memory: (int * int option) option; (** min, max pages *)
+  mod_data: (int * string) list; (** offset, data *)
+}
+
+(** Code generation context: name-to-index tables plus the counters that hand out the next index *)
+type context = {
+  ctx_funcs: (string, int) Hashtbl.t; (** function name -> index *)
+  ctx_globals: (string, int) Hashtbl.t; (** global name -> index *)
+  ctx_locals: (string, int) Hashtbl.t; (** local name -> index *)
+  mutable ctx_local_count: int; (** next local index handed out by add_local *)
+  mutable ctx_func_count: int; (** next function index handed out by add_func *)
+  mutable ctx_global_count: int; (** starts at 0; no update is visible in this part of the file *)
+  mutable ctx_label_depth: int; (** starts at 0; no update is visible in this part of the file *)
+  ctx_types: wasm_func_type list ref;
+  ctx_module: wasm_module ref;
+}
+
+let create_context () = { (* empty tables, all counters at zero, and an empty module with one memory *)
+  ctx_funcs = Hashtbl.create 64;
+  ctx_globals = Hashtbl.create 32;
+  ctx_locals = Hashtbl.create 32;
+  ctx_local_count = 0;
+  ctx_func_count 
= 0;
+  ctx_global_count = 0;
+  ctx_label_depth = 0;
+  ctx_types = ref [];
+  ctx_module = ref {
+    mod_types = [];
+    mod_imports = [];
+    mod_funcs = [];
+    mod_globals = [];
+    mod_exports = [];
+    mod_memory = Some (1, Some 16); (* 1 page min, 16 pages max *)
+    mod_data = [];
+  };
+}
+
+(** Add a local variable, returning its index. Indices are handed out sequentially from [ctx_local_count]; adding a name again rebinds it to the new index. The third argument (the local type) is accepted so call sites keep working but is not recorded yet, hence the underscore prefix: a plain [ty] is an unused variable (warning 27). *)
+let add_local ctx name _ty =
+  let idx = ctx.ctx_local_count in
+  Hashtbl.replace ctx.ctx_locals name idx;
+  ctx.ctx_local_count <- idx + 1;
+  idx
+
+(** Look up local variable index; [None] when the name was never added *)
+let get_local ctx name =
+  Hashtbl.find_opt ctx.ctx_locals name
+
+(** Add a function, returning its index; indices are handed out sequentially from [ctx_func_count] and adding a name again rebinds it *)
+let add_func ctx name =
+  let idx = ctx.ctx_func_count in
+  Hashtbl.replace ctx.ctx_funcs name idx;
+  ctx.ctx_func_count <- idx + 1;
+  idx
+
+(** Look up function index; [None] when the name was never added *)
+let get_func ctx name =
+  Hashtbl.find_opt ctx.ctx_funcs name
+
+(** Convert AffineScript type to WASM type. Only [TFloat] maps to [F64]; every other type, including the catch-all, is represented as [I32]. The function never calls itself, so it is a plain [let]: a [rec] here is an unused rec flag (warning 39). *)
+let type_to_wasm ty =
+  match ty with
+  | Types.TInt | Types.TNat -> I32
+  | Types.TFloat -> F64
+  | Types.TBool -> I32
+  | Types.TChar -> I32
+  | Types.TUnit -> I32 (* Unit = 0 *)
+  | Types.TString -> I32 (* Pointer to string data *)
+  | Types.TRef _ | Types.TMut _ | Types.TOwn _ -> I32 (* Pointers *)
+  | Types.TTuple _ -> I32 (* Pointer to tuple *)
+  | Types.TRecord _ -> I32 (* Pointer to record *)
+  | Types.TApp ("Array", _) -> I32 (* Pointer to array *)
+  | Types.TArrow _ -> I32 (* Function reference index *)
+  | _ -> I32 (* Default to i32 *)
+
+(** Compile literal to WASM: a single constant instruction; booleans become 1 or 0, characters their code, unit 0 *)
+let compile_literal = function
+  | LitInt (n, _) -> [I32Const n]
+  | LitFloat (f, _) -> [F64Const f]
+  | LitBool (true, _) -> [I32Const 1]
+  | LitBool (false, _) -> [I32Const 0]
+  | LitChar (c, _) -> [I32Const (Char.code c)]
+  | LitString (_, _) -> [I32Const 0] (* TODO: string allocation *)
+  | LitUnit _ -> [I32Const 0]
+
+(** Compile binary operator; every operator is mapped to an i32 instruction regardless of operand type *)
+let compile_binop op =
+  match op with
+  | OpAdd -> [I32Add]
+  | OpSub -> [I32Sub]
+  | OpMul -> [I32Mul]
+  | OpDiv -> [I32DivS]
+  | 
OpMod -> [I32RemS]
+  | OpEq -> [I32Eq]
+  | OpNe -> [I32Ne]
+  | OpLt -> [I32LtS]
+  | OpLe -> [I32LeS]
+  | OpGt -> [I32GtS]
+  | OpGe -> [I32GeS]
+  | OpAnd -> [I32And] (* same instruction as OpBitAnd; the ExprBinary case emits both operands first, so there is no short-circuit *)
+  | OpOr -> [I32Or] (* same instruction as OpBitOr; both operands are always evaluated *)
+  | OpBitAnd -> [I32And]
+  | OpBitOr -> [I32Or]
+  | OpBitXor -> [I32Xor]
+  | OpShl -> [I32Shl]
+  | OpShr -> [I32ShrS]
+
+(** Compile unary operator to the i32 instructions appended after the already compiled operand *)
+let compile_unop op =
+  match op with
+  | OpNeg -> [I32Const (-1); I32Mul] (* negate by multiplying with -1 *)
+  | OpNot -> [I32Eqz] (* 1 when the operand is 0, otherwise 0 *)
+  | OpBitNot -> [I32Const (-1); I32Xor] (* xor with all ones *)
+  | OpRef -> [] (* Handled specially; NOTE(review): the ExprUnary case visible in compile_expr applies compile_unop to every operator, so a reference currently emits nothing extra - confirm *)
+  | OpDeref -> [I32Load (2, 0)] (* Load from pointer: align 2, offset 0 *)
+
+(** Compile expression to WASM instructions *)
+let rec compile_expr ctx expr =
+  match expr with
+  | ExprSpan (e, _) -> compile_expr ctx e
+
+  | ExprLit lit -> compile_literal lit
+
+  | ExprVar id ->
+    (match get_local ctx id.name with
+     | Some idx -> [LocalGet idx]
+     | None ->
+       match get_func ctx id.name with
+       | Some idx -> [I32Const idx] (* Function reference *)
+       | None -> [I32Const 0]) (* Unknown - return 0 *)
+
+  | ExprLet { el_pat; el_value; el_body; _ } ->
+    let value_instrs = compile_expr ctx el_value in
+    let bind_instrs = compile_pattern_bind ctx el_pat in
+    let body_instrs = match el_body with
+      | Some body -> compile_expr ctx body
+      | None -> [I32Const 0]
+    in
+    value_instrs @ bind_instrs @ body_instrs
+
+  | ExprIf { ei_cond; ei_then; ei_else } ->
+    let cond_instrs = compile_expr ctx ei_cond in
+    let then_instrs = compile_expr ctx ei_then in
+    let else_instrs = match ei_else with
+      | Some e -> Some (compile_expr ctx e)
+      | None -> Some [I32Const 0]
+    in
+    cond_instrs @ [If (Some I32, then_instrs, else_instrs)]
+
+  | ExprMatch { em_scrutinee; em_arms } ->
+    (* Simplified: compile as nested if-else chain *)
+    let scrutinee_instrs = compile_expr ctx em_scrutinee in
+    let idx = add_local ctx "_match" I32 in
+    let set_instr = [LocalSet idx] in
+    let arms_instrs = compile_match_arms ctx idx em_arms in
+    scrutinee_instrs @ set_instr @ arms_instrs
+
+  | ExprLambda _ ->
+    (* Lambdas compile to 
function indices *) + [I32Const 0] (* TODO: closure support *) + + | ExprApp (func, args) -> + let func_instrs = compile_expr ctx func in + let args_instrs = List.concat_map (compile_expr ctx) args in + (* For now, assume direct call via function index *) + args_instrs @ func_instrs @ [CallIndirect 0] (* TODO: proper call *) + + | ExprField (e, _field) -> + (* Record field access - compute offset and load *) + let record_instrs = compile_expr ctx e in + record_instrs @ [I32Load (2, 0)] (* TODO: field offset *) + + | ExprTupleIndex (e, idx) -> + let tuple_instrs = compile_expr ctx e in + tuple_instrs @ [I32Const (idx * 4); I32Add; I32Load (2, 0)] + + | ExprIndex (arr, idx) -> + let arr_instrs = compile_expr ctx arr in + let idx_instrs = compile_expr ctx idx in + arr_instrs @ idx_instrs @ [I32Const 4; I32Mul; I32Add; I32Load (2, 0)] + + | ExprTuple exprs -> + (* Allocate tuple on heap and store elements *) + let size = List.length exprs * 4 in + let alloc = [I32Const size; Call 0] (* Call malloc *) + in + let stores = List.mapi (fun i e -> + let elem_instrs = compile_expr ctx e in + [LocalGet 0] (* Base pointer *) + @ [I32Const (i * 4); I32Add] + @ elem_instrs + @ [I32Store (2, 0)] + ) exprs |> List.concat in + alloc @ stores @ [LocalGet 0] + + | ExprArray exprs -> + (* Similar to tuple but with length prefix *) + let len = List.length exprs in + let size = (len + 1) * 4 in + let alloc = [I32Const size; Call 0] in + let store_len = [ + LocalGet 0; + I32Const len; + I32Store (2, 0) + ] in + let stores = List.mapi (fun i e -> + let elem_instrs = compile_expr ctx e in + [LocalGet 0; I32Const ((i + 1) * 4); I32Add] + @ elem_instrs + @ [I32Store (2, 0)] + ) exprs |> List.concat in + alloc @ store_len @ stores @ [LocalGet 0] + + | ExprRecord { er_fields; _ } -> + (* Similar to tuple *) + let size = List.length er_fields * 4 in + let alloc = [I32Const size; Call 0] in + let stores = List.mapi (fun i (_, expr_opt) -> + let elem_instrs = match expr_opt with + | Some e -> 
compile_expr ctx e + | None -> [I32Const 0] + in + [LocalGet 0; I32Const (i * 4); I32Add] + @ elem_instrs + @ [I32Store (2, 0)] + ) er_fields |> List.concat in + alloc @ stores @ [LocalGet 0] + + | ExprRowRestrict (e, _) -> + compile_expr ctx e (* Simplified *) + + | ExprBinary (e1, op, e2) -> + let e1_instrs = compile_expr ctx e1 in + let e2_instrs = compile_expr ctx e2 in + let op_instrs = compile_binop op in + e1_instrs @ e2_instrs @ op_instrs + + | ExprUnary (op, e) -> + let e_instrs = compile_expr ctx e in + let op_instrs = compile_unop op in + e_instrs @ op_instrs + + | ExprBlock blk -> + compile_block ctx blk + + | ExprReturn e_opt -> + (match e_opt with + | Some e -> compile_expr ctx e @ [Return] + | None -> [I32Const 0; Return]) + + | ExprVariant (_, _) -> + [I32Const 0] (* TODO: variant encoding *) + + | ExprTry _ | ExprHandle _ | ExprResume _ | ExprUnsafe _ -> + [I32Const 0] (* TODO *) + +and compile_pattern_bind ctx pat = + match pat with + | PatWildcard _ -> [Drop] + | PatVar id -> + let idx = add_local ctx id.name I32 in + [LocalSet idx] + | PatLit _ -> [Drop] + | PatTuple pats -> + (* Value is pointer to tuple on stack *) + let base_idx = add_local ctx "_tuple_base" I32 in + [LocalSet base_idx] @ + List.concat (List.mapi (fun i p -> + [LocalGet base_idx; I32Const (i * 4); I32Add; I32Load (2, 0)] + @ compile_pattern_bind ctx p + ) pats) + | PatRecord _ | PatCon _ | PatOr _ | PatAs _ -> + [Drop] (* Simplified *) + +and compile_match_arms ctx scrutinee_idx = function + | [] -> [I32Const 0] + | [arm] -> + (* Last arm - just execute *) + compile_expr ctx arm.ma_body + | arm :: rest -> + (* Check pattern, if match execute body, else try next *) + let check_instrs = compile_pattern_check ctx scrutinee_idx arm.ma_pat in + let body_instrs = compile_expr ctx arm.ma_body in + let else_instrs = compile_match_arms ctx scrutinee_idx rest in + check_instrs @ [If (Some I32, body_instrs, Some else_instrs)] + +and compile_pattern_check _ctx scrutinee_idx = function + 
| PatWildcard _ -> [I32Const 1] (* Always matches *) + | PatVar _ -> [I32Const 1] (* Always matches *) + | PatLit lit -> + [LocalGet scrutinee_idx] @ + compile_literal lit @ + [I32Eq] + | _ -> [I32Const 1] (* Simplified *) + +and compile_block ctx { blk_stmts; blk_expr } = + let stmt_instrs = List.concat_map (compile_stmt ctx) blk_stmts in + let expr_instrs = match blk_expr with + | Some e -> compile_expr ctx e + | None -> [I32Const 0] + in + stmt_instrs @ expr_instrs + +and compile_stmt ctx = function + | StmtLet { sl_pat; sl_value; _ } -> + compile_expr ctx sl_value @ compile_pattern_bind ctx sl_pat + + | StmtExpr e -> + compile_expr ctx e @ [Drop] + + | StmtAssign (target, _, value) -> + (match target with + | ExprVar id -> + (match get_local ctx id.name with + | Some idx -> compile_expr ctx value @ [LocalSet idx] + | None -> []) + | ExprIndex (arr, idx) -> + compile_expr ctx arr @ + compile_expr ctx idx @ + [I32Const 4; I32Mul; I32Add] @ + compile_expr ctx value @ + [I32Store (2, 0)] + | _ -> []) + + | StmtWhile (cond, body) -> + ctx.ctx_label_depth <- ctx.ctx_label_depth + 1; + let cond_instrs = compile_expr ctx cond in + let body_instrs = compile_block ctx body in + ctx.ctx_label_depth <- ctx.ctx_label_depth - 1; + [Block (None, [ + Loop (None, + cond_instrs @ + [I32Eqz; BrIf 1] @ (* Exit if condition false *) + body_instrs @ + [Drop] @ (* Discard block result *) + [Br 0] (* Continue loop *) + ) + ])] + + | StmtFor (pat, iter, body) -> + (* Desugar to while loop *) + let iter_instrs = compile_expr ctx iter in + let idx_local = add_local ctx "_for_idx" I32 in + let arr_local = add_local ctx "_for_arr" I32 in + let len_local = add_local ctx "_for_len" I32 in + iter_instrs @ + [LocalSet arr_local] @ + [LocalGet arr_local; I32Load (2, 0); LocalSet len_local] @ (* Get length *) + [I32Const 0; LocalSet idx_local] @ + [Block (None, [ + Loop (None, + (* Check idx < len *) + [LocalGet idx_local; LocalGet len_local; I32GeS; BrIf 1] @ + (* Get element *) + [LocalGet 
arr_local; + LocalGet idx_local; I32Const 4; I32Mul; + I32Const 4; I32Add; (* Skip length field *) + I32Add; I32Load (2, 0)] @ + compile_pattern_bind ctx pat @ + compile_block ctx body @ + [Drop] @ + (* Increment idx *) + [LocalGet idx_local; I32Const 1; I32Add; LocalSet idx_local] @ + [Br 0] + ) + ])] + +(** Compile a function declaration *) +let compile_fn_decl ctx (decl: fn_decl) = + (* Reset locals *) + Hashtbl.clear ctx.ctx_locals; + ctx.ctx_local_count <- 0; + (* Add parameters as locals *) + let param_types = List.map (fun p -> + let _ = add_local ctx p.p_name.name I32 in + I32 (* Simplified: all params are i32 *) + ) decl.fd_params in + (* Compile body *) + let body_instrs = match decl.fd_body with + | FnBlock blk -> compile_block ctx blk + | FnExpr e -> compile_expr ctx e + in + let extra_locals = ctx.ctx_local_count - List.length decl.fd_params in + let local_types = List.init extra_locals (fun _ -> I32) in + { + fn_name = decl.fd_name.name; + fn_type = { + ft_params = param_types; + ft_results = [I32]; (* Simplified: all functions return i32 *) + }; + fn_locals = local_types; + fn_body = body_instrs; + fn_export = decl.fd_vis = Public; + } + +(** Compile a program to WASM module *) +let compile_program (prog: program) = + let ctx = create_context () in + (* First pass: register all functions *) + List.iter (fun decl -> + match decl with + | TopFn fd -> ignore (add_func ctx fd.fd_name.name) + | _ -> () + ) prog.prog_decls; + (* Second pass: compile functions *) + let funcs = List.filter_map (fun decl -> + match decl with + | TopFn fd -> Some (compile_fn_decl ctx fd) + | _ -> None + ) prog.prog_decls in + (* Build exports *) + let exports = List.filter_map (fun fn -> + if fn.fn_export then + match get_func ctx fn.fn_name with + | Some idx -> Some (ExportFunc (fn.fn_name, idx)) + | None -> None + else None + ) funcs in + { + mod_types = List.map (fun fn -> fn.fn_type) funcs; + mod_imports = [ + (* Import memory allocator *) + ImportFunc ("env", "malloc", { 
ft_params = [I32]; ft_results = [I32] }); + ImportFunc ("env", "print", { ft_params = [I32]; ft_results = [] }); + ]; + mod_funcs = funcs; + mod_globals = []; + mod_exports = exports @ [ExportMemory "memory"]; + mod_memory = Some (1, Some 256); + mod_data = []; + } + +(** Emit WASM binary *) +let emit_binary _module = + (* TODO: Implement actual WASM binary encoding *) + (* For now, return a placeholder *) + Bytes.empty + +(** Emit WASM text format (WAT) *) +let emit_wat module_ = + let buf = Buffer.create 4096 in + let add s = Buffer.add_string buf s in + let addln s = Buffer.add_string buf s; Buffer.add_char buf '\n' in + + addln "(module"; + + (* Types *) + List.iteri (fun i ft -> + add (Printf.sprintf " (type $t%d (func" i); + if ft.ft_params <> [] then begin + add " (param"; + List.iter (fun t -> + add (match t with I32 -> " i32" | I64 -> " i64" | F32 -> " f32" | F64 -> " f64" | _ -> " i32") + ) ft.ft_params; + add ")" + end; + if ft.ft_results <> [] then begin + add " (result"; + List.iter (fun t -> + add (match t with I32 -> " i32" | I64 -> " i64" | F32 -> " f32" | F64 -> " f64" | _ -> " i32") + ) ft.ft_results; + add ")" + end; + addln "))" + ) module_.mod_types; + + (* Imports *) + List.iter (fun imp -> + match imp with + | ImportFunc (mod_name, func_name, ft) -> + addln (Printf.sprintf " (import \"%s\" \"%s\" (func $%s_%s (param i32) (result i32)))" + mod_name func_name mod_name func_name) + | ImportMemory (mod_name, mem_name, min, max) -> + let max_str = match max with Some m -> Printf.sprintf " %d" m | None -> "" in + addln (Printf.sprintf " (import \"%s\" \"%s\" (memory %d%s))" + mod_name mem_name min max_str) + | _ -> () + ) module_.mod_imports; + + (* Memory *) + (match module_.mod_memory with + | Some (min, max) -> + let max_str = match max with Some m -> Printf.sprintf " %d" m | None -> "" in + addln (Printf.sprintf " (memory (export \"memory\") %d%s)" min max_str) + | None -> ()); + + (* Functions *) + List.iteri (fun i fn -> + add (Printf.sprintf 
" (func $%s (type $t%d)" fn.fn_name i); + List.iter (fun _ -> add " (local i32)") fn.fn_locals; + addln ""; + (* Emit body - simplified *) + List.iter (fun instr -> + add " "; + (match instr with + | I32Const n -> addln (Printf.sprintf "i32.const %d" n) + | LocalGet n -> addln (Printf.sprintf "local.get %d" n) + | LocalSet n -> addln (Printf.sprintf "local.set %d" n) + | I32Add -> addln "i32.add" + | I32Sub -> addln "i32.sub" + | I32Mul -> addln "i32.mul" + | I32DivS -> addln "i32.div_s" + | I32Eq -> addln "i32.eq" + | I32Ne -> addln "i32.ne" + | I32LtS -> addln "i32.lt_s" + | I32GtS -> addln "i32.gt_s" + | I32LeS -> addln "i32.le_s" + | I32GeS -> addln "i32.ge_s" + | Return -> addln "return" + | Drop -> addln "drop" + | _ -> addln "nop") + ) fn.fn_body; + addln " )" + ) module_.mod_funcs; + + (* Exports *) + List.iter (fun exp -> + match exp with + | ExportFunc (name, idx) -> + addln (Printf.sprintf " (export \"%s\" (func %d))" name idx) + | _ -> () + ) module_.mod_exports; + + addln ")"; + Buffer.contents buf diff --git a/lib/dependent.ml b/lib/dependent.ml new file mode 100644 index 0000000..f5eb5e7 --- /dev/null +++ b/lib/dependent.ml @@ -0,0 +1,349 @@ +(** Dependent types for AffineScript - length-indexed vectors and refinements *) + +open Types + +(** Nat-level expressions for type indices *) +type nat_expr = + | NLit of int (** Literal natural number *) + | NVar of string (** Nat variable *) + | NAdd of nat_expr * nat_expr (** n + m *) + | NSub of nat_expr * nat_expr (** n - m (saturating at 0) *) + | NMul of nat_expr * nat_expr (** n * m *) + | NMax of nat_expr * nat_expr (** max(n, m) *) + | NMin of nat_expr * nat_expr (** min(n, m) *) + | NLen of string (** Length of array variable *) + | NSizeof of ty (** Size of type in bytes *) +[@@deriving show, eq] + +(** Nat constraints *) +type nat_constraint = + | NCEq of nat_expr * nat_expr (** n = m *) + | NCLt of nat_expr * nat_expr (** n < m *) + | NCLe of nat_expr * nat_expr (** n <= m *) + | NCAnd of 
nat_constraint * nat_constraint + | NCOr of nat_constraint * nat_constraint + | NCNot of nat_constraint + | NCTrue + | NCFalse +[@@deriving show, eq] + +(** Dependent type context *) +type dep_context = { + dc_nat_vars: (string, nat_expr option) Hashtbl.t; (** name -> concrete value if known *) + dc_constraints: nat_constraint list; + dc_array_lengths: (string, nat_expr) Hashtbl.t; (** array var -> length *) +} + +let create_dep_context () = { + dc_nat_vars = Hashtbl.create 16; + dc_constraints = []; + dc_array_lengths = Hashtbl.create 16; +} + +(** Simplify a nat expression *) +let rec simplify_nat = function + | NLit n -> NLit n + | NVar v -> NVar v + | NAdd (n1, n2) -> + let n1' = simplify_nat n1 in + let n2' = simplify_nat n2 in + (match n1', n2' with + | NLit a, NLit b -> NLit (a + b) + | NLit 0, n | n, NLit 0 -> n + | _ -> NAdd (n1', n2')) + | NSub (n1, n2) -> + let n1' = simplify_nat n1 in + let n2' = simplify_nat n2 in + (match n1', n2' with + | NLit a, NLit b -> NLit (max 0 (a - b)) + | n, NLit 0 -> n + | _ -> NSub (n1', n2')) + | NMul (n1, n2) -> + let n1' = simplify_nat n1 in + let n2' = simplify_nat n2 in + (match n1', n2' with + | NLit a, NLit b -> NLit (a * b) + | NLit 0, _ | _, NLit 0 -> NLit 0 + | NLit 1, n | n, NLit 1 -> n + | _ -> NMul (n1', n2')) + | NMax (n1, n2) -> + let n1' = simplify_nat n1 in + let n2' = simplify_nat n2 in + (match n1', n2' with + | NLit a, NLit b -> NLit (max a b) + | _ -> NMax (n1', n2')) + | NMin (n1, n2) -> + let n1' = simplify_nat n1 in + let n2' = simplify_nat n2 in + (match n1', n2' with + | NLit a, NLit b -> NLit (min a b) + | _ -> NMin (n1', n2')) + | NLen v -> NLen v + | NSizeof t -> NSizeof t + +(** Substitute nat variables *) +let rec subst_nat_var name value expr = + match expr with + | NLit n -> NLit n + | NVar v when v = name -> value + | NVar v -> NVar v + | NAdd (n1, n2) -> NAdd (subst_nat_var name value n1, subst_nat_var name value n2) + | NSub (n1, n2) -> NSub (subst_nat_var name value n1, subst_nat_var 
name value n2) + | NMul (n1, n2) -> NMul (subst_nat_var name value n1, subst_nat_var name value n2) + | NMax (n1, n2) -> NMax (subst_nat_var name value n1, subst_nat_var name value n2) + | NMin (n1, n2) -> NMin (subst_nat_var name value n1, subst_nat_var name value n2) + | NLen v -> NLen v + | NSizeof t -> NSizeof t + +(** Check if a nat expression is concrete (no variables) *) +let rec is_concrete = function + | NLit _ -> true + | NVar _ -> false + | NAdd (n1, n2) | NSub (n1, n2) | NMul (n1, n2) | NMax (n1, n2) | NMin (n1, n2) -> + is_concrete n1 && is_concrete n2 + | NLen _ -> false + | NSizeof _ -> true + +(** Evaluate concrete nat expression *) +let rec eval_nat = function + | NLit n -> Some n + | NVar _ -> None + | NAdd (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> a + b) (eval_nat n2)) + | NSub (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> max 0 (a - b)) (eval_nat n2)) + | NMul (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> a * b) (eval_nat n2)) + | NMax (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> max a b) (eval_nat n2)) + | NMin (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> min a b) (eval_nat n2)) + | NLen _ -> None + | NSizeof _ -> Some 4 (* Default size *) + +(** Check nat constraint (when possible) *) +let rec check_constraint = function + | NCTrue -> Some true + | NCFalse -> Some false + | NCEq (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> a = b) (eval_nat n2)) + | NCLt (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> a < b) (eval_nat n2)) + | NCLe (n1, n2) -> + Option.bind (eval_nat n1) (fun a -> + Option.map (fun b -> a <= b) (eval_nat n2)) + | NCAnd (c1, c2) -> + Option.bind (check_constraint c1) (fun a -> + Option.map (fun b -> a && b) (check_constraint c2)) + | NCOr (c1, c2) -> + Option.bind (check_constraint c1) (fun a -> + Option.map (fun b -> a || b) 
(check_constraint c2)) + | NCNot c -> + Option.map not (check_constraint c) + +(** Dependent vector type: Vec[n, T] *) +type dep_vec = { + dv_len: nat_expr; + dv_elem: ty; +} +[@@deriving show] + +(** Create a dependent vector type *) +let vec_type len elem = + TApp ("Vec", [TRigid (Printf.sprintf "%s" (show_nat_expr len)); elem]) + +(** Extract length from Vec type if possible *) +let vec_length = function + | TApp ("Vec", [len_ty; _]) -> + (match len_ty with + | TRigid s -> + (* Try to parse as nat literal *) + (try Some (NLit (int_of_string s)) + with _ -> Some (NVar s)) + | _ -> None) + | _ -> None + +(** Dependent function type: (n: Nat) -> Vec[n, T] -> Vec[n, T] *) +type dep_arrow = { + da_nat_params: string list; (** Nat parameters *) + da_type_params: string list; (** Type parameters *) + da_param_ty: ty; (** Parameter type *) + da_return_ty: ty; (** Return type *) + da_constraints: nat_constraint list; (** Constraints on nat params *) +} +[@@deriving show] + +(** Matrix type: Matrix[m, n, T] *) +let matrix_type m n elem = + TApp ("Matrix", [ + TRigid (show_nat_expr m); + TRigid (show_nat_expr n); + elem + ]) + +(** Refinement type: { x: Int | x > 0 } *) +type refined_type = { + rt_base: ty; + rt_var: string; + rt_pred: nat_constraint; +} +[@@deriving show] + +(** Create positive integer type *) +let pos_int_type = + TRefined (TInt, RTrue) (* Simplified - would need proper predicate *) + +(** Create bounded integer type: { x: Int | lo <= x && x < hi } *) +let bounded_int lo hi = + let constraint_ = NCAnd ( + NCLe (NLit lo, NVar "x"), + NCLt (NVar "x", NLit hi) + ) in + { rt_base = TInt; rt_var = "x"; rt_pred = constraint_ } + +(** Index type for arrays: { i: Nat | i < len(arr) } *) +let index_type arr_name = + let constraint_ = NCLt (NVar "i", NLen arr_name) in + { rt_base = TNat; rt_var = "i"; rt_pred = constraint_ } + +(** Dependent type errors *) +type dep_error = + | NatMismatch of nat_expr * nat_expr + | ConstraintViolation of nat_constraint + | 
IndexOutOfBounds of nat_expr * nat_expr (** index, length *) + | LengthMismatch of string * nat_expr * nat_expr + | UnsolvedConstraint of nat_constraint +[@@deriving show] + +exception Dep_error of dep_error + +(** Check array index bounds *) +let check_bounds ctx index_expr len_expr = + let constraint_ = NCLt (index_expr, len_expr) in + match check_constraint constraint_ with + | Some true -> Ok () + | Some false -> Error (IndexOutOfBounds (index_expr, len_expr)) + | None -> + (* Add to constraints for later solving *) + ctx.dc_constraints <- constraint_ :: ctx.dc_constraints; + Ok () + +(** Unify two nat expressions *) +let rec unify_nat n1 n2 = + let n1' = simplify_nat n1 in + let n2' = simplify_nat n2 in + if equal_nat_expr n1' n2' then + Ok [] + else match n1', n2' with + | NVar v, n | n, NVar v -> + Ok [(v, n)] + | NAdd (a1, a2), NAdd (b1, b2) -> + Result.bind (unify_nat a1 b1) (fun s1 -> + Result.map (fun s2 -> s1 @ s2) (unify_nat a2 b2)) + | NLit a, NLit b when a = b -> Ok [] + | _ -> Error (NatMismatch (n1', n2')) + +(** Check that a dependent function call is valid *) +let check_dep_call ctx dep_fn arg_nats = + (* Substitute nat arguments into constraints and check *) + let subst_constraints = List.fold_left2 (fun constrs param_name arg -> + List.map (fun c -> + let rec subst_in_constraint = function + | NCEq (n1, n2) -> NCEq (subst_nat_var param_name arg n1, subst_nat_var param_name arg n2) + | NCLt (n1, n2) -> NCLt (subst_nat_var param_name arg n1, subst_nat_var param_name arg n2) + | NCLe (n1, n2) -> NCLe (subst_nat_var param_name arg n1, subst_nat_var param_name arg n2) + | NCAnd (c1, c2) -> NCAnd (subst_in_constraint c1, subst_in_constraint c2) + | NCOr (c1, c2) -> NCOr (subst_in_constraint c1, subst_in_constraint c2) + | NCNot c -> NCNot (subst_in_constraint c) + | NCTrue -> NCTrue + | NCFalse -> NCFalse + in + subst_in_constraint c + ) constrs + ) dep_fn.da_constraints dep_fn.da_nat_params arg_nats in + (* Check all constraints *) + List.iter (fun 
c -> + match check_constraint c with + | Some false -> raise (Dep_error (ConstraintViolation c)) + | Some true -> () + | None -> ctx.dc_constraints <- c :: ctx.dc_constraints + ) subst_constraints + +(** Common dependent type signatures *) + +(** append: Vec[n, T] -> Vec[m, T] -> Vec[n + m, T] *) +let vec_append_sig = + let n = NVar "n" in + let m = NVar "m" in + { + da_nat_params = ["n"; "m"]; + da_type_params = ["T"]; + da_param_ty = TTuple [vec_type n (TRigid "T"); vec_type m (TRigid "T")]; + da_return_ty = vec_type (NAdd (n, m)) (TRigid "T"); + da_constraints = []; + } + +(** head: Vec[n + 1, T] -> T *) +let vec_head_sig = + let n = NVar "n" in + { + da_nat_params = ["n"]; + da_type_params = ["T"]; + da_param_ty = vec_type (NAdd (n, NLit 1)) (TRigid "T"); + da_return_ty = TRigid "T"; + da_constraints = []; + } + +(** tail: Vec[n + 1, T] -> Vec[n, T] *) +let vec_tail_sig = + let n = NVar "n" in + { + da_nat_params = ["n"]; + da_type_params = ["T"]; + da_param_ty = vec_type (NAdd (n, NLit 1)) (TRigid "T"); + da_return_ty = vec_type n (TRigid "T"); + da_constraints = []; + } + +(** zip: Vec[n, A] -> Vec[n, B] -> Vec[n, (A, B)] *) +let vec_zip_sig = + let n = NVar "n" in + { + da_nat_params = ["n"]; + da_type_params = ["A"; "B"]; + da_param_ty = TTuple [vec_type n (TRigid "A"); vec_type n (TRigid "B")]; + da_return_ty = vec_type n (TTuple [TRigid "A"; TRigid "B"]); + da_constraints = []; + } + +(** transpose: Matrix[m, n, T] -> Matrix[n, m, T] *) +let matrix_transpose_sig = + let m = NVar "m" in + let n = NVar "n" in + { + da_nat_params = ["m"; "n"]; + da_type_params = ["T"]; + da_param_ty = matrix_type m n (TRigid "T"); + da_return_ty = matrix_type n m (TRigid "T"); + da_constraints = []; + } + +(** matmul: Matrix[m, n, T] -> Matrix[n, p, T] -> Matrix[m, p, T] *) +let matrix_mul_sig = + let m = NVar "m" in + let n = NVar "n" in + let p = NVar "p" in + { + da_nat_params = ["m"; "n"; "p"]; + da_type_params = ["T"]; + da_param_ty = TTuple [matrix_type m n (TRigid 
"T"); matrix_type n p (TRigid "T")]; + da_return_ty = matrix_type m p (TRigid "T"); + da_constraints = []; + } diff --git a/lib/effects.ml b/lib/effects.ml new file mode 100644 index 0000000..1fcda33 --- /dev/null +++ b/lib/effects.ml @@ -0,0 +1,379 @@ +(** Algebraic effects and handlers for AffineScript *) + +open Ast +open Value + +(** Effect operation *) +type effect_op = { + eo_name: string; + eo_effect: string; + eo_params: Types.ty list; + eo_result: Types.ty; +} +[@@deriving show] + +(** Effect definition *) +type effect_def = { + ef_name: string; + ef_type_params: string list; + ef_ops: effect_op list; +} +[@@deriving show] + +(** Effect handler clause *) +type handler_clause = + | HReturn of string * expr (** return x -> e *) + | HOp of string * string list * string * expr (** op(args) k -> e *) +[@@deriving show] + +(** Effect handler *) +type handler = { + h_effect: string; + h_clauses: handler_clause list; +} +[@@deriving show] + +(** Continuation representation *) +type continuation = { + k_env: env; + k_expr: expr; + k_hole: string; (** Variable to substitute result into *) +} + +(** Effect stack frame *) +type effect_frame = { + ef_handler: handler; + ef_cont: continuation; +} + +(** Effect runtime state *) +type effect_state = { + es_stack: effect_frame list; + es_suspended: (string * t * continuation) list; (** Suspended computations *) +} + +let create_effect_state () = { + es_stack = []; + es_suspended = []; +} + +(** Built-in effects *) + +(** State effect *) +let state_effect = { + ef_name = "State"; + ef_type_params = ["S"]; + ef_ops = [ + { eo_name = "get"; eo_effect = "State"; eo_params = []; eo_result = Types.TRigid "S" }; + { eo_name = "put"; eo_effect = "State"; eo_params = [Types.TRigid "S"]; eo_result = Types.TUnit }; + ]; +} + +(** Exception effect *) +let exn_effect = { + ef_name = "Exn"; + ef_type_params = ["E"]; + ef_ops = [ + { eo_name = "throw"; eo_effect = "Exn"; eo_params = [Types.TRigid "E"]; eo_result = Types.TNever }; + ]; +} 
+ +(** Async effect *) +let async_effect = { + ef_name = "Async"; + ef_type_params = []; + ef_ops = [ + { eo_name = "await"; eo_effect = "Async"; eo_params = [Types.TApp ("Future", [Types.TRigid "T"])]; eo_result = Types.TRigid "T" }; + { eo_name = "spawn"; eo_effect = "Async"; eo_params = [Types.TArrow (Types.TUnit, Types.TRigid "T", Types.EEmpty)]; eo_result = Types.TApp ("Future", [Types.TRigid "T"]) }; + { eo_name = "yield_"; eo_effect = "Async"; eo_params = []; eo_result = Types.TUnit }; + ]; +} + +(** Reader effect *) +let reader_effect = { + ef_name = "Reader"; + ef_type_params = ["R"]; + ef_ops = [ + { eo_name = "ask"; eo_effect = "Reader"; eo_params = []; eo_result = Types.TRigid "R" }; + { eo_name = "local"; eo_effect = "Reader"; + eo_params = [Types.TArrow (Types.TRigid "R", Types.TRigid "R", Types.EEmpty); Types.TArrow (Types.TUnit, Types.TRigid "A", Types.ECon ("Reader", []))]; + eo_result = Types.TRigid "A" }; + ]; +} + +(** Writer effect *) +let writer_effect = { + ef_name = "Writer"; + ef_type_params = ["W"]; + ef_ops = [ + { eo_name = "tell"; eo_effect = "Writer"; eo_params = [Types.TRigid "W"]; eo_result = Types.TUnit }; + ]; +} + +(** Choice/NonDet effect *) +let choice_effect = { + ef_name = "Choice"; + ef_type_params = []; + ef_ops = [ + { eo_name = "choose"; eo_effect = "Choice"; eo_params = []; eo_result = Types.TBool }; + { eo_name = "fail"; eo_effect = "Choice"; eo_params = []; eo_result = Types.TNever }; + ]; +} + +(** Console IO effect *) +let console_effect = { + ef_name = "Console"; + ef_type_params = []; + ef_ops = [ + { eo_name = "print"; eo_effect = "Console"; eo_params = [Types.TString]; eo_result = Types.TUnit }; + { eo_name = "read_line"; eo_effect = "Console"; eo_params = []; eo_result = Types.TString }; + ]; +} + +(** All built-in effects *) +let builtin_effects = [ + state_effect; + exn_effect; + async_effect; + reader_effect; + writer_effect; + choice_effect; + console_effect; +] + +(** Effect operation result *) +type 
op_result = + | OpReturn of t (** Operation completed with value *) + | OpSuspend of string * t list * continuation (** Operation suspended *) + | OpAbort of t (** Handler aborted with value *) + +(** Find handler for effect operation *) +let find_handler state effect_name op_name = + List.find_opt (fun frame -> + frame.ef_handler.h_effect = effect_name && + List.exists (function + | HOp (name, _, _, _) -> name = op_name + | _ -> false + ) frame.ef_handler.h_clauses + ) state.es_stack + +(** Perform an effect operation *) +let perform state env effect_name op_name args cont = + match find_handler state effect_name op_name with + | None -> + (* No handler - this is an error *) + failwith (Printf.sprintf "Unhandled effect operation: %s.%s" effect_name op_name) + | Some frame -> + (* Find the operation clause *) + let clause = List.find_map (function + | HOp (name, params, k_name, body) when name = op_name -> + Some (params, k_name, body) + | _ -> None + ) frame.ef_handler.h_clauses in + match clause with + | None -> + failwith (Printf.sprintf "Missing clause for operation: %s" op_name) + | Some (params, k_name, body) -> + (* Bind parameters and continuation *) + let handler_env = child_env env in + List.iter2 (fun param arg -> + bind handler_env param arg ~mutable_:false ~linear:false + ) params args; + (* Bind continuation as a function *) + let k_value = VBuiltin ("continuation", fun resume_args -> + match resume_args with + | [result] -> + (* Resume the suspended computation *) + let resume_env = child_env cont.k_env in + bind resume_env cont.k_hole result ~mutable_:false ~linear:false; + Eval.eval resume_env cont.k_expr + | _ -> failwith "Continuation expects exactly one argument" + ) in + bind handler_env k_name k_value ~mutable_:false ~linear:false; + (* Evaluate handler body *) + OpReturn (Eval.eval handler_env body) + +(** Install a handler and run computation *) +let handle state handler computation env = + let frame = { + ef_handler = handler; + ef_cont = { 
k_env = env; k_expr = computation; k_hole = "_result" }; + } in + let state' = { state with es_stack = frame :: state.es_stack } in + try + let result = Eval.eval env computation in + (* Computation completed - run return clause *) + let return_clause = List.find_map (function + | HReturn (x, body) -> Some (x, body) + | _ -> None + ) handler.h_clauses in + match return_clause with + | Some (x, body) -> + let return_env = child_env env in + bind return_env x result ~mutable_:false ~linear:false; + Eval.eval return_env body + | None -> result + with + | Eval.Runtime_error _ as e -> + (* Pop handler and re-raise *) + raise e + +(** State effect handler implementation *) +let state_handler init_state = + let state_ref = ref init_state in + { + h_effect = "State"; + h_clauses = [ + HReturn ("x", ExprVar { name = "x"; span = Span.dummy }); + (* get() k -> k(!state_ref) *) + (* put(s) k -> state_ref := s; k(()) *) + ]; + } + +(** Exception handler implementation *) +let exn_handler on_throw = + { + h_effect = "Exn"; + h_clauses = [ + HReturn ("x", ExprLit (LitUnit Span.dummy)); (* wrap in Ok *) + (* throw(e) k -> on_throw(e) *) + ]; + } + +(** Reader handler implementation *) +let reader_handler env_value = + { + h_effect = "Reader"; + h_clauses = [ + HReturn ("x", ExprVar { name = "x"; span = Span.dummy }); + (* ask() k -> k(env_value) *) + (* local(f, m) k -> handle[Reader(f(env_value))] { m() } |> k *) + ]; + } + +(** Writer handler implementation *) +let writer_handler = + { + h_effect = "Writer"; + h_clauses = [ + HReturn ("x", ExprVar { name = "x"; span = Span.dummy }); (* (x, []) *) + (* tell(w) k -> let (a, ws) = k(()) in (a, w :: ws) *) + ]; + } + +(** Choice handler implementation - returns list of all results *) +let choice_list_handler = + { + h_effect = "Choice"; + h_clauses = [ + HReturn ("x", ExprVar { name = "x"; span = Span.dummy }); (* [x] *) + (* choose() k -> k(true) @ k(false) *) + (* fail() k -> [] *) + ]; + } + +(** Async runtime state *) +type 
async_state = { + mutable as_ready: (int * (unit -> t)) list; (** Ready tasks *) + mutable as_waiting: (int * (t -> unit -> t)) list; (** Waiting for value *) + mutable as_next_id: int; +} + +let create_async_state () = { + as_ready = []; + as_waiting = []; + as_next_id = 0; +} + +(** Future value *) +type future_state = + | Pending + | Resolved of t + +type future = { + f_id: int; + mutable f_state: future_state; + mutable f_waiters: (t -> unit) list; +} + +(** Async effect interpreter *) +let run_async computation env = + let async_state = create_async_state () in + let futures : (int, future) Hashtbl.t = Hashtbl.create 16 in + + let make_future () = + let id = async_state.as_next_id in + async_state.as_next_id <- id + 1; + let f = { f_id = id; f_state = Pending; f_waiters = [] } in + Hashtbl.replace futures id f; + f + in + + let resolve_future f value = + f.f_state <- Resolved value; + List.iter (fun waiter -> waiter value) f.f_waiters; + f.f_waiters <- [] + in + + (* Main scheduler loop *) + let rec run_scheduler () = + match async_state.as_ready with + | [] -> () (* All done *) + | (id, task) :: rest -> + async_state.as_ready <- rest; + let _ = task () in + run_scheduler () + in + + (* Start main computation *) + async_state.as_ready <- [(0, fun () -> Eval.eval env computation)]; + run_scheduler (); + VUnit + +(** Effect type checking *) + +(** Effect row - set of effects *) +type effect_row = + | EffPure (** No effects *) + | EffVar of string (** Effect variable *) + | EffCons of string * effect_row (** E | r *) +[@@deriving show] + +(** Check if effect is in row *) +let rec effect_in_row eff = function + | EffPure -> false + | EffVar _ -> true (* Unknown row might contain it *) + | EffCons (e, rest) -> e = eff || effect_in_row eff rest + +(** Subtract effect from row *) +let rec subtract_effect eff = function + | EffPure -> EffPure + | EffVar v -> EffVar v (* Can't remove from unknown row *) + | EffCons (e, rest) -> + if e = eff then rest + else EffCons (e, 
subtract_effect eff rest)
+
+(** Unify effect rows: [Ok] carries row-variable bindings, [Error] a message *)
+let rec unify_effects r1 r2 =
+  match r1, r2 with
+  | EffPure, EffPure -> Ok []
+  | EffVar v, r | r, EffVar v -> Ok [(v, r)]
+  | EffCons (e1, rest1), EffCons (e2, rest2) when e1 = e2 ->
+    unify_effects rest1 rest2
+  | EffCons (e1, rest1), r2 ->
+    if effect_in_row e1 r2 then
+      unify_effects rest1 (subtract_effect e1 r2)
+    else
+      Error (Printf.sprintf "Effect %s not in row" e1)
+  | _ -> Error "Cannot unify effect rows"
+
+(** Check that computation's effects are handled (row variables are assumed handled) *)
+let check_handled handlers comp_effects =
+  let handled_effects = List.map (fun h -> h.h_effect) handlers in
+  let rec check = function
+    | EffPure -> true
+    | EffVar _ -> true  (* Polymorphic - assumed ok *)
+    | EffCons (e, rest) ->
+      List.mem e handled_effects && check rest
+  in
+  check comp_effects
diff --git a/lib/eval.ml b/lib/eval.ml
new file mode 100644
index 0000000..b61f4ad
--- /dev/null
+++ b/lib/eval.ml
@@ -0,0 +1,568 @@
+(** Tree-walking interpreter for AffineScript *)
+
+open Ast
+open Value
+
+(** Interpreter errors: a message plus the source span, when one is known *)
+exception Runtime_error of string * Span.t option
+
+(** [error ?span msg] raises [Runtime_error (msg, span)]; it never returns. *)
+let error ?span msg = raise (Runtime_error (msg, span))
+
+(** Get span from expression.
+    Only [ExprSpan], literals and [ExprVar] are inspected; every other
+    expression form yields [None]. *)
+let expr_span = function
+  | ExprSpan (_, span) -> Some span
+  | ExprLit (LitInt (_, span)) -> Some span
+  | ExprLit (LitFloat (_, span)) -> Some span
+  | ExprLit (LitBool (_, span)) -> Some span
+  | ExprLit (LitChar (_, span)) -> Some span
+  | ExprLit (LitString (_, span)) -> Some span
+  | ExprLit (LitUnit span) -> Some span
+  | ExprVar { span; _ } -> Some span
+  | _ -> None
+
+(** Evaluate a literal: map an AST literal to its runtime value, dropping the span *)
+let eval_literal = function
+  | LitInt (i, _) -> VInt i
+  | LitFloat (f, _) -> VFloat f
+  | LitBool (b, _) -> VBool b
+  | LitChar (c, _) -> VChar c
+  | LitString (s, _) -> VString s
+  | LitUnit _ -> VUnit
+
+(** Binary operation on integers.
+    Returns [Error] for division or modulo by zero, for a shift amount outside
+    [0 .. Sys.int_size], for [OpAnd]/[OpOr], and when an operand is not a [VInt]. *)
+let int_binop op v1 v2 =
+  (* OCaml leaves [lsl]/[asr] unspecified for a negative shift amount or one
+     above [Sys.int_size], so reject those instead of returning garbage. *)
+  let shift f a b =
+    if b < 0 || b > Sys.int_size then Error "Shift amount out of range"
+    else Ok (VInt (f a b))
+  in
+  match v1, v2 with
+  | VInt a, VInt b -> (
+    match op with
+    | OpAdd -> Ok (VInt (a + b))
+    | OpSub -> Ok (VInt (a - b))
+    | OpMul -> Ok (VInt (a * b))
+    | OpDiv ->
+      if b = 0 then Error "Division by zero"
+      else Ok (VInt (a / b))
+    | OpMod ->
+      if b = 0 then Error "Modulo by zero"
+      else Ok (VInt (a mod b))
+    | OpEq -> Ok (VBool (a = b))
+    | OpNe -> Ok (VBool (a <> b))
+    | OpLt -> Ok (VBool (a < b))
+    | OpLe -> Ok (VBool (a <= b))
+    | OpGt -> Ok (VBool (a > b))
+    | OpGe -> Ok (VBool (a >= b))
+    | OpBitAnd -> Ok (VInt (a land b))
+    | OpBitOr -> Ok (VInt (a lor b))
+    | OpBitXor -> Ok (VInt (a lxor b))
+    | OpShl -> shift ( lsl ) a b
+    | OpShr -> shift ( asr ) a b
+    | OpAnd | OpOr -> Error "Boolean operator on integers"
+  )
+  | _ -> Error "Type mismatch in integer operation"
+
+(** Binary operation on floats.
+    A [VInt] operand paired with a [VFloat] is promoted with [Float.of_int]
+    and the operation is retried on two floats (hence the [rec]). *)
+let rec float_binop op v1 v2 =
+  match v1, v2 with
+  | VFloat a, VFloat b -> (
+    match op with
+    | OpAdd -> Ok (VFloat (a +. b))
+    | OpSub -> Ok (VFloat (a -. b))
+    | OpMul -> Ok (VFloat (a *. b))
+    | OpDiv -> Ok (VFloat (a /. b))
+    | OpEq -> Ok (VBool (a = b))
+    | OpNe -> Ok (VBool (a <> b))
+    | OpLt -> Ok (VBool (a < b))
+    | OpLe -> Ok (VBool (a <= b))
+    | OpGt -> Ok (VBool (a > b))
+    | OpGe -> Ok (VBool (a >= b))
+    | _ -> Error "Invalid operation on floats"
+  )
+  | VInt a, VFloat b -> float_binop op (VFloat (Float.of_int a)) (VFloat b)
+  | VFloat a, VInt b -> float_binop op (VFloat a) (VFloat (Float.of_int b))
+  | _ -> Error "Type mismatch in float operation"
+
+(** Boolean binary operations: [OpAnd], [OpOr], [OpEq] and [OpNe] on two [VBool]s *)
+let bool_binop op v1 v2 =
+  match v1, v2 with
+  | VBool a, VBool b -> (
+    match op with
+    | OpAnd -> Ok (VBool (a && b))
+    | OpOr -> Ok (VBool (a || b))
+    | OpEq -> Ok (VBool (a = b))
+    | OpNe -> Ok (VBool (a <> b))
+    | _ -> Error "Invalid operation on booleans"
+  )
+  | _ -> Error "Type mismatch in boolean operation"
+
+(** String operations: [OpAdd] concatenates, the rest compare two [VString]s *)
+let string_binop op v1 v2 =
+  match v1, v2 with
+  | VString a, VString b -> (
+    match op with
+    | OpAdd -> Ok (VString (a ^ b))
+    | OpEq -> Ok (VBool (a = b))
+    | OpNe -> Ok (VBool (a <> b))
+    | OpLt -> Ok (VBool (a < b))
+    | OpLe -> Ok (VBool (a <= b))
+    | OpGt -> Ok (VBool (a > b))
+    | OpGe -> Ok (VBool (a >= b))
+    | _ -> Error "Invalid operation on strings"
+  )
+  | _ -> Error "Type mismatch in string operation"
+
+(** Evaluate binary expression.
+    Dispatches on the operand constructors; operand pairs without a dedicated
+    handler support only [OpEq]/[OpNe], decided by [Value.equal]. *)
+let eval_binop op v1 v2 =
+  match v1, v2 with
+  | VInt _, VInt _ -> int_binop op v1 v2
+  | VFloat _, _ | _, VFloat _ -> float_binop op v1 v2
+  | VBool _, VBool _ -> bool_binop op v1 v2
+  | VString _, VString _ -> string_binop op v1 v2
+  | _ ->
+    match op with
+    | OpEq -> Ok (VBool (Value.equal v1 v2))
+    | OpNe -> Ok (VBool (not (Value.equal v1 v2)))
+    | _ -> Error (Printf.sprintf "Cannot apply operator to %s and %s"
+                    (Value.show v1) (Value.show v2))
+
+(** Evaluate unary expression; [Error] when [op] does not apply to [v] *)
+let eval_unop op v =
+  match op, v with
+  | OpNeg, VInt i -> Ok (VInt (-i))
+  | OpNeg, VFloat f -> Ok (VFloat (-.f))
+  | OpNot, VBool b -> Ok (VBool (not b))
+  | OpBitNot, VInt i -> Ok (VInt (lnot i))
+  | OpRef, v -> Ok (VRef (ref v))
+  | OpDeref, VRef r -> Ok (!r)
+  | OpDeref, _ -> Error "Cannot dereference non-reference value"
+  | _ -> Error (Printf.sprintf "Invalid unary operation on %s" (Value.show v))
+
+(* Merge per-sub-pattern results: [None] if any sub-pattern failed to match *)
+let merge_bindings results =
+  if List.for_all Option.is_some results then
+    Some (List.concat_map Option.get results)
+  else None
+
+(** Match a pattern against a value, returning bindings if successful.
+    Each binding is [(name, value, flag)]; every binding built here has the
+    flag set to [false]. *)
+let rec match_pattern pat value : (string * t * bool) list option =
+  match pat with
+  | PatWildcard _ -> Some []
+  | PatVar id -> Some [(id.name, value, false)]
+  | PatLit lit ->
+    if Value.equal (eval_literal lit) value then Some [] else None
+  | PatTuple pats -> (
+    match value with
+    | VTuple values when List.length pats = List.length values ->
+      merge_bindings (List.map2 match_pattern pats values)
+    | _ -> None
+  )
+  | PatRecord (fields, has_rest) -> (
+    match value with
+    | VRecord rec_fields ->
+      let match_field (id, pat_opt) =
+        match List.assoc_opt id.name rec_fields, pat_opt with
+        | Some v, None -> Some [(id.name, v, false)]
+        | Some v, Some p -> match_pattern p v
+        (* NOTE(review): a field missing from the value still matches (and
+           binds nothing) when [has_rest] is set - confirm this is intended. *)
+        | None, _ -> if has_rest then Some [] else None
+      in
+      merge_bindings (List.map match_field fields)
+    | _ -> None
+  )
+  | PatCon (id, pats) -> (
+    match value with
+    | VVariant (_, variant, args)
+      when id.name = variant && List.length pats = List.length args ->
+      merge_bindings (List.map2 match_pattern pats args)
+    | _ -> None
+  )
+  | PatOr (p1, p2) -> (
+    match match_pattern p1 value with
+    | Some _ as matched -> matched
+    | None -> match_pattern p2 value
+  )
+  | PatAs (id, p) ->
+    Option.map
+      (fun bindings -> (id.name, value, false) :: bindings)
+      (match_pattern p value)
+
+(** Control flow exceptions: [Return_exn] is caught by [apply] around a closure
+    body, [Break_exn]/[Continue_exn] by the loops in [eval_stmt]. *)
+exception Return_exn of t
+exception Break_exn
+exception Continue_exn
+
+(** Main evaluation function.
+    Evaluates [expr] in [env] and returns its value.
+    @raise Runtime_error on any evaluation failure
+    @raise Return_exn when an [ExprReturn] is evaluated *)
+let rec eval env expr =
+  match expr with
+  | ExprSpan (e, _) -> eval env e
+
+  | ExprLit lit -> eval_literal lit
+
+  | ExprVar id -> (
+    match Value.consume env id.name with
+    | Ok v -> v
+    | Error msg -> error ~span:id.span msg
+  )
+
+  | ExprLet { el_mut; el_pat; el_value; el_body; _ } ->
+    let value = eval env el_value in
+    let bindings = match match_pattern el_pat value with
+      | Some b -> b
+      | None -> error "Pattern match failed in let binding"
+    in
+    List.iter (fun (name, v, _) ->
+      Value.bind env name v ~mutable_:el_mut ~linear:false
+    ) bindings;
+    (match el_body with
+     | Some body -> eval env body
+     | None -> VUnit)
+
+  | ExprIf { ei_cond; ei_then; ei_else } -> (
+    match Value.to_bool (eval env ei_cond), ei_else with
+    | Ok true, _ -> eval env ei_then
+    | Ok false, Some e -> eval env e
+    | Ok false, None -> VUnit
+    | Error msg, _ -> error msg
+  )
+
+  | ExprMatch { em_scrutinee; em_arms } ->
+    let value = eval env em_scrutinee in
+    let rec try_arms = function
+      | [] -> error "No matching pattern in match expression"
+      | arm :: rest ->
+        match match_pattern arm.ma_pat value with
+        | None -> try_arms rest
+        | Some bindings ->
+          let match_env = child_env env in
+          List.iter (fun (name, v, _) ->
+            Value.bind match_env name v ~mutable_:false ~linear:false
+          ) bindings;
+          (* Check guard if present *)
+          let guard_ok = match arm.ma_guard with
+            | None -> true
+            | Some guard ->
+              match Value.to_bool (eval match_env guard) with
+              | Ok b -> b
+              (* NOTE(review): a non-boolean guard counts as "no match"
+                 instead of raising like [ExprIf] does - confirm. *)
+              | Error _ -> false
+          in
+          if guard_ok then eval match_env arm.ma_body
+          else try_arms rest
+    in
+    try_arms em_arms
+
+  | ExprLambda { elam_params; elam_body; _ } ->
+    VClosure { params = elam_params; body = elam_body; env }
+
+  | ExprApp (func, args) ->
+    let func_val = eval env func in
+    let arg_vals = List.map (eval env) args in
+    apply func_val arg_vals
+
+  | ExprField (e, field) ->
+    let value = eval env e in
+    (match value with
+     | VRecord fields -> (
+       match List.assoc_opt field.name fields with
+       | Some v -> v
+       | None -> error ~span:field.span (Printf.sprintf "Field '%s' not found" field.name)
+     )
+     | _ -> error "Cannot access field on non-record value")
+
+  | ExprTupleIndex (e, idx) ->
+    let value = eval env e in
+    (match value with
+     | VTuple values ->
+       if idx >= 0 && idx < List.length values then
+         List.nth values idx
+       else
+         error (Printf.sprintf "Tuple index %d out of bounds" idx)
+     | _ -> error "Cannot index non-tuple value")
+
+  | ExprIndex (e, idx_expr) ->
+    let value = eval env e in
+    let idx_val = eval env idx_expr in
+    (match value, idx_val with
+     | VArray arr, VInt i ->
+       if i >= 0 && i < Array.length arr then arr.(i)
+       else error (Printf.sprintf "Array index %d out of bounds" i)
+     | VString s, VInt i ->
+       if i >= 0 && i < String.length s then VChar (String.get s i)
+       else error (Printf.sprintf "String index %d out of bounds" i)
+     | _ -> error "Invalid indexing operation")
+
+  | ExprTuple exprs ->
+    VTuple (List.map (eval env) exprs)
+
+  | ExprArray exprs ->
+    VArray (Array.of_list (List.map (eval env) exprs))
+
+  | ExprRecord { er_fields; er_spread } ->
+    let base_fields = match er_spread with
+      | Some spread_expr ->
+        (match eval env spread_expr with
+         | VRecord fields -> fields
+         | _ -> error "Spread must be a record")
+      | None -> []
+    in
+    let new_fields = List.map (fun (id, expr_opt) ->
+      let value = match expr_opt with
+        | Some e -> eval env e
+        | None ->
+          (* Shorthand: {x} means {x: x} *)
+          match Value.consume env id.name with
+          | Ok v -> v
+          | Error msg -> error ~span:id.span msg
+      in
+      (id.name, value)
+    ) er_fields in
+    (* New fields override spread fields *)
+    let merged = List.fold_left (fun acc (k, v) ->
+      (k, v) :: List.remove_assoc k acc
+    ) base_fields new_fields in
+    VRecord merged
+
+  | ExprRowRestrict (e, field) ->
+    let value = eval env e in
+    (match value with
+     | VRecord fields ->
+       VRecord (List.remove_assoc field.name fields)
+     | _ -> error "Cannot restrict non-record value")
+
+  | ExprBinary (e1, op, e2) ->
+    (* Short-circuit evaluation for && and || *)
+    (match op with
+     | OpAnd ->
+       let v1 = eval env e1 in
+       (match Value.to_bool v1 with
+        | Ok false -> VBool false
+        | Ok true -> eval env e2
+        | Error msg -> error msg)
+     | OpOr ->
+       let v1 = eval env e1 in
+       (match Value.to_bool v1 with
+        | Ok true -> VBool true
+        | Ok false -> eval env e2
+        | Error msg -> error msg)
+     | _ ->
+       let v1 = eval env e1 in
+       let v2 = eval env e2 in
+       match eval_binop op v1 v2 with
+       | Ok v -> v
+       | Error msg -> error msg)
+
+  | ExprUnary (op, e) ->
+    let v = eval env e in
+    (match eval_unop op v with
+     | Ok v -> v
+     | Error msg -> error msg)
+
+  | ExprBlock blk -> eval_block env blk
+
+  | ExprReturn e_opt ->
+    let value = match e_opt with
+      | Some e -> eval env e
+      | None -> VUnit
+    in
+    raise (Return_exn value)
+
+  | ExprVariant (type_id, variant_id) ->
+    VVariant (type_id.name, variant_id.name, [])
+
+  | ExprTry _ -> error "Try/catch not yet implemented"
+  | ExprHandle _ -> error "Effect handlers not yet implemented"
+  | ExprResume _ -> error "Resume not yet implemented"
+  | ExprUnsafe _ -> error "Unsafe operations not yet implemented"
+
+(** Evaluate a block of statements in a child scope; yields the trailing expression or [VUnit] *)
+and eval_block env { blk_stmts; blk_expr } =
+  let block_env = child_env env in
+  List.iter (eval_stmt block_env) blk_stmts;
+  match blk_expr with
+  | Some e -> eval block_env e
+  | None -> VUnit
+
+(** Evaluate a statement *)
+and eval_stmt env = function
+  | StmtLet { sl_mut; sl_pat; sl_value; _ } ->
+    let value = eval env sl_value in
+    let bindings = match match_pattern sl_pat value with
+      | Some b -> b
+      | None -> error "Pattern match failed in let binding"
+    in
+    List.iter (fun (name, v, _) ->
+      Value.bind env name v ~mutable_:sl_mut ~linear:false
+    ) bindings
+
+  | StmtExpr e -> ignore (eval env e)
+
+  | StmtAssign (target, op, value_expr) ->
+    let new_value = eval env value_expr in
+    (match target with
+     | ExprVar id ->
+       let final_value = match op with
+         | AssignEq -> new_value
+         | AssignAdd | AssignSub | AssignMul | AssignDiv as assign_op ->
+           (match Value.lookup env id.name with
+            | Some binding ->
+              let binop = match assign_op with
+                | AssignAdd -> OpAdd
+                | AssignSub -> OpSub
+                | AssignMul -> OpMul
+                | AssignDiv -> OpDiv
+                | AssignEq -> OpAdd  (* unreachable *)
+              in
+              (match eval_binop binop binding.value new_value with
+               | Ok v -> v
+               | Error msg -> error msg)
+            | None -> error ~span:id.span (Printf.sprintf "Unbound variable: %s" id.name))
+       in
+       if not (Value.update env id.name final_value) then
+         error ~span:id.span (Printf.sprintf "Cannot assign to immutable variable: %s" id.name)
+     | ExprIndex (arr_expr, idx_expr) ->
+       let arr = eval env arr_expr in
+       let idx = eval env idx_expr in
+       (match arr, idx with
+        | VArray arr, VInt i ->
+          if i >= 0 && i < Array.length arr then
+            arr.(i) <- new_value
+          else
+            error (Printf.sprintf "Array index %d out of bounds" i)
+        | _ -> error "Invalid assignment target")
+     | ExprField (rec_expr, field) ->
+       (* The record expression is still evaluated, then the update is rejected *)
+       ignore (eval env rec_expr);
+       error ~span:field.span "Direct field mutation not supported; use record update syntax"
+     | _ -> error "Invalid assignment target")
+
+  | StmtWhile (cond, body) ->
+    (* The recursive call sits outside the [try] so that it is a tail call:
+       long-running loops must not grow the stack on every iteration. *)
+    let rec loop () =
+      match Value.to_bool (eval env cond) with
+      | Ok true ->
+        let keep_going =
+          try ignore (eval_block env body); true
+          with
+          | Break_exn -> false
+          | Continue_exn -> true
+        in
+        if keep_going then loop ()
+      | Ok false -> ()
+      | Error msg -> error msg
+    in
+    loop ()
+
+  | StmtFor (pat, iter_expr, body) ->
+    let iter_val = eval env iter_expr in
+    let items = match iter_val with
+      | VArray arr -> Array.to_list arr
+      | VTuple items -> items
+      | VString s -> List.init (String.length s) (fun i -> VChar (String.get s i))
+      | _ -> error "Cannot iterate over this value"
+    in
+    (* [Break_exn] is caught around the whole iteration so that it ends this
+       loop, as it does in [StmtWhile], instead of escaping to the caller. *)
+    (try
+      List.iter (fun item ->
+        match match_pattern pat item with
+        | None -> error "Pattern match failed in for loop"
+        | Some bindings ->
+          let iter_env = child_env env in
+          List.iter (fun (name, v, _) ->
+            Value.bind iter_env name v ~mutable_:false ~linear:false
+          ) bindings;
+          try ignore (eval_block iter_env body)
+          with Continue_exn -> ()
+      ) items
+    with Break_exn -> ())
+
+(** Apply a function to arguments.
+    Closures run their body in a child of the captured environment, binding
+    each parameter ([QOne] parameters are bound as linear); builtins receive
+    the argument list; applying a variant appends to its arguments. *)
+and apply func_val args =
+  match func_val with
+  | VClosure { params; body; env } ->
+    if List.compare_lengths params args <> 0 then
+      error (Printf.sprintf "Function expects %d arguments, got %d"
+        (List.length params) (List.length args));
+    let call_env = child_env env in
+    List.iter2 (fun param arg ->
+      let linear = match param.p_quantity with
+        | Some QOne -> true
+        | _ -> false
+      in
+      Value.bind call_env param.p_name.name arg ~mutable_:false ~linear
+    ) params args;
+    (try eval call_env body
+     with Return_exn v -> v)
+  | VBuiltin (_, f) -> f args
+  | VVariant (ty, variant, existing_args) ->
+    (* Variant constructor application *)
+    VVariant (ty, variant, existing_args @ args)
+  | _ -> error (Printf.sprintf "Cannot call non-function value: %s" (Value.show func_val))
+
+(** Evaluate a function
declaration and bind it *)
+let eval_fn_decl env (decl : fn_decl) =
+  (* Block bodies are wrapped so that every closure body is an expression *)
+  let body = match decl.fd_body with
+    | FnBlock blk -> ExprBlock blk
+    | FnExpr e -> e
+  in
+  (* The closure captures [env], the same environment it is bound in *)
+  let closure = VClosure { params = decl.fd_params; body; env } in
+  Value.bind env decl.fd_name.name closure ~mutable_:false ~linear:false
+
+(** Evaluate a type declaration: bind one constructor value per enum variant *)
+let eval_type_decl env (decl : type_decl) =
+  let ty_name = decl.td_name.name in
+  let register (v : variant_decl) =
+    let ctor_name = v.vd_name.name in
+    let constructor = match v.vd_fields with
+      | [] -> VVariant (ty_name, ctor_name, [])
+      (* Variants with fields are bound to a builtin that builds the variant *)
+      | _ :: _ ->
+        VBuiltin (ctor_name, fun args -> VVariant (ty_name, ctor_name, args))
+    in
+    Value.bind env ctor_name constructor ~mutable_:false ~linear:false
+  in
+  match decl.td_body with
+  | TyEnum variants -> List.iter register variants
+  | _ -> ()  (* Structs/aliases don't create value bindings *)
+
+(** Evaluate a top-level declaration: binds functions, constructors and constants in [env] *)
+let eval_top_level env = function
+  | TopFn decl -> eval_fn_decl env decl
+  | TopType decl -> eval_type_decl env decl
+  | TopConst { tc_name; tc_value; _ } ->
+    let value = eval env tc_value in
+    Value.bind env tc_name.name value ~mutable_:false ~linear:false
+  | TopEffect _ -> ()  (* Effects are compile-time only *)
+  | TopTrait _ -> ()   (* Traits are compile-time only *)
+  | TopImpl _ -> ()    (* Impls are compile-time only *)
+
+(** Evaluate a complete program: its declarations, in order *)
+let eval_program env (prog : program) =
+  List.iter (eval_top_level env) prog.prog_decls
diff --git a/lib/lsp.ml b/lib/lsp.ml
new file mode 100644
index 0000000..ac93384
--- /dev/null
+++ b/lib/lsp.ml
@@ -0,0 +1,598 @@
+(** Language Server Protocol implementation for AffineScript *)
+
+(** LSP Position *)
+type position = {
+  line: int;
+  character: int;
+}
+[@@deriving show]
+
+(** LSP Range *)
+type range = {
+  start: position;
+  end_: position;
+}
+[@@deriving show]
+
+(** LSP Location *)
+type location = {
+  uri: string;
+  range: range;
+}
+[@@deriving show]
+
+(** Diagnostic severity *)
+type severity =
+  | Error
+  | Warning
+  | Information
+  | Hint
+[@@deriving show]
+
+(** LSP Diagnostic *)
+type diagnostic = {
+  range: range;
+  severity: severity;
+  code: string option;
+  source: string;
+  message: string;
+}
+[@@deriving show]
+
+(** Completion item kind *)
+type completion_kind =
+  | CKText
+  | CKMethod
+  | CKFunction
+  | CKConstructor
+  | CKField
+  | CKVariable
+  | CKClass
+  | CKInterface
+  | CKModule
+  | CKProperty
+  | CKUnit
+  | CKValue
+  | CKEnum
+  | CKKeyword
+  | CKSnippet
+  | CKColor
+  | CKFile
+  | CKReference
+  | CKFolder
+  | CKEnumMember
+  | CKConstant
+  | CKStruct
+  | CKEvent
+  | CKOperator
+  | CKTypeParameter
+[@@deriving show]
+
+(** Completion item *)
+type completion_item = {
+  label: string;
+  kind: completion_kind;
+  detail: string option;
+  documentation: string option;
+  insert_text: string option;
+  insert_text_format: int;  (** 1 = plain, 2 = snippet *)
+}
+[@@deriving show]
+
+(** Hover content *)
+type hover = {
+  contents: string;
+  range: range option;
+}
+[@@deriving show]
+
+(** Symbol kind *)
+type symbol_kind =
+  | SKFile
+  | SKModule
+  | SKNamespace
+  | SKPackage
+  | SKClass
+  | SKMethod
+  | SKProperty
+  | SKField
+  | SKConstructor
+  | SKEnum
+  | SKInterface
+  | SKFunction
+  | SKVariable
+  | SKConstant
+  | SKString
+  | SKNumber
+  | SKBoolean
+  | SKArray
+  | SKObject
+  | SKKey
+  | SKNull
+  | SKEnumMember
+  | SKStruct
+  | SKEvent
+  | SKOperator
+  | SKTypeParameter
+[@@deriving show]
+
+(** Document symbol *)
+type document_symbol = {
+  name: string;
+  kind: symbol_kind;
+  range: range;
+  selection_range: range;
+  children: document_symbol list;
+}
+[@@deriving show]
+
+(** Server capabilities *)
+type capabilities = {
+  text_document_sync: int;  (** 0=None, 1=Full, 2=Incremental *)
+  hover_provider: bool;
+  completion_provider: bool;
+  definition_provider: bool;
+  references_provider: bool;
+  document_symbol_provider: bool;
+  workspace_symbol_provider: bool;
+  code_action_provider: bool;
+  rename_provider: bool;
+  formatting_provider: bool;
+  signature_help_provider: bool;
+}
+[@@deriving show]
+
+let default_capabilities = {
+  text_document_sync = 1;
+  hover_provider = true;
+  completion_provider = true;
+  definition_provider = true;
+  references_provider = true;
+  document_symbol_provider = true;
+  workspace_symbol_provider = true;
+  code_action_provider = true;
+  rename_provider = true;
+  formatting_provider = true;
+  signature_help_provider = true;
+}
+
+(** Document state *)
+type document = {
+  doc_uri: string;
+  doc_version: int;
+  doc_content: string;
+  mutable doc_ast: Ast.program option;
+  mutable doc_symbols: Symbol.symbol_table option;
+  mutable doc_diagnostics: diagnostic list;
+}
+
+(** LSP Server state *)
+type server_state = {
+  mutable initialized: bool;
+  mutable shutdown: bool;
+  documents: (string, document) Hashtbl.t;
+  mutable root_uri: string option;
+  mutable capabilities: capabilities;
+}
+
+let create_server () = {
+  initialized = false;
+  shutdown = false;
+  documents = Hashtbl.create 32;
+  root_uri = None;
+  capabilities = default_capabilities;
+}
+
+(** Convert span to LSP range: line and column are both shifted down by one *)
+let span_to_range (span: Span.t) : range = {
+  start = { line = span.start_pos.line - 1; character = span.start_pos.col - 1 };
+  end_ = { line = span.end_pos.line - 1; character = span.end_pos.col - 1 };
+}
+
+(** Placeholder range (0:0 to 0:0) for diagnostics that are given no span here *)
+let zero_range =
+  { start = { line = 0; character = 0 }; end_ = { line = 0; character = 0 } }
+
+(** Parse and analyze document.
+    On success stores the AST and symbol table in [doc] and replaces its
+    diagnostics with the resolve, type and borrow errors; a lexer or parse
+    error replaces them with a single diagnostic. Other exceptions propagate. *)
+let analyze_document doc =
+  try
+    let prog = Parse_driver.parse_string ~file:doc.doc_uri doc.doc_content in
+    doc.doc_ast <- Some prog;
+    let (symtab, resolve_errors) = Resolve.resolve_program prog in
+    doc.doc_symbols <- Some symtab;
+    let type_errors = Typecheck.check_program prog in
+    let borrow_errors = Borrow.check_program prog in
+    (* Convert errors to diagnostics *)
+    let resolve_diags = List.map (fun err ->
+      let msg = Resolve.error_to_string err in
+      let range = match err with
+        | Resolve.UnboundVariable (_, span) -> span_to_range span
+        | Resolve.UnboundType (_, span) -> span_to_range span
+        | Resolve.DuplicateDefinition (_, span, _) -> span_to_range span
+        | _ -> zero_range
+      in
+      { range; severity = Error; code = Some "E0001"; source = "affinescript"; message = msg }
+    ) resolve_errors in
+    let type_diags = List.map (fun err ->
+      let msg = Typecheck.error_to_string err in
+      { range = zero_range;
+        severity = Error; code = Some "E0002"; source = "affinescript"; message = msg }
+    ) type_errors in
+    let borrow_diags = List.map (fun err ->
+      let msg = Borrow.error_to_string err in
+      { range = zero_range;
+        severity = Error; code = Some "E0003"; source = "affinescript"; message = msg }
+    ) borrow_errors in
+    doc.doc_diagnostics <- resolve_diags @ type_diags @ borrow_diags
+  with
+  | Lexer.Lexer_error (msg, pos) ->
+    doc.doc_diagnostics <- [{
+      range = { start = { line = pos.Span.line - 1; character = pos.Span.col - 1 };
+                end_ = { line = pos.Span.line - 1; character = pos.Span.col } };
+      severity = Error;
+      code = Some "E0000";
+      source = "affinescript";
+      message = "Lexer error: " ^ msg;
+    }]
+  | Parse_driver.Parse_error (msg, span) ->
+    doc.doc_diagnostics <- [{
+      range = span_to_range span;
+      severity = Error;
+      code = Some "E0000";
+      source = "affinescript";
+      message = "Parse error: " ^ msg;
+    }]
+
+(** Get completions at position.
+    Offers keywords, built-in types, built-in functions and, when the document
+    has a symbol table, the symbols of its current scope. The position is
+    not used yet. Returns [[]] for an unknown [uri]. *)
+let get_completions server uri _pos =
+  match Hashtbl.find_opt server.documents uri with
+  | None -> []
+  | Some doc ->
+    (* Every item is plain text (insert_text_format = 1) without documentation *)
+    let item ?detail label kind : completion_item =
+      { label; kind; detail; documentation = None;
+        insert_text = None; insert_text_format = 1 }
+    in
+    (* Conses [f x] onto [acc] for each [x], so later entries come first *)
+    let push_all f xs acc = List.fold_left (fun acc x -> f x :: acc) acc xs in
+    let keywords = [
+      "fn"; "let"; "mut"; "if"; "else"; "match"; "while"; "for"; "in";
+      "return"; "break"; "continue"; "struct"; "enum"; "trait"; "impl";
+      "type"; "effect"; "handle"; "resume"; "pub"; "use"; "module";
+      "true"; "false"; "own"; "ref"; "total"; "unsafe";
+    ] in
+    let types = ["Int"; "Bool"; "Float"; "Char"; "String"; "Unit"; "Never"; "Array"; "Option"; "Result"] in
+    let builtins = [
+      ("print", "fn(any) -> ()");
+      ("println", "fn(any) -> ()");
+      ("len", "fn(Array[T]) -> Int");
+      ("range", "fn(Int) -> Array[Int]");
+      ("map", "fn((T -> U), Array[T]) -> Array[U]");
+      ("filter", "fn((T -> Bool), Array[T]) -> Array[T]");
+      ("fold", "fn((A, T -> A), A, Array[T]) -> A");
+    ] in
+    let scope_symbols = match doc.doc_symbols with
+      | Some symtab -> Symbol.symbols_in_scope (Symbol.current_scope symtab)
+      | None -> []
+    in
+    let symbol_item sym =
+      let kind = match sym.Symbol.sym_kind with
+        | Symbol.SKVariable -> CKVariable
+        | Symbol.SKFunction -> CKFunction
+        | Symbol.SKType -> CKClass
+        | Symbol.SKTrait -> CKInterface
+        | Symbol.SKEffect -> CKEvent
+        | Symbol.SKVariant -> CKEnumMember
+        | Symbol.SKField -> CKField
+        | Symbol.SKModule -> CKModule
+        | _ -> CKValue
+      in
+      item sym.Symbol.sym_name kind
+    in
+    []
+    |> push_all (fun kw -> item kw CKKeyword) keywords
+    |> push_all (fun ty -> item ~detail:"built-in type" ty CKClass) types
+    |> push_all (fun (name, sig_) -> item ~detail:sig_ name CKFunction) builtins
+    |> push_all symbol_item scope_symbols
+
+(** Get hover information (placeholder text for any known document) *)
+let get_hover server uri _pos =
+  match Hashtbl.find_opt server.documents uri with
+  | None -> None
+  | Some _doc ->
+    (* Find symbol at position *)
+    (* For now, return generic info *)
+    Some { contents = "AffineScript symbol"; range = None }
+
+(** Get definition location (not implemented: always [None]) *)
+let get_definition server uri _pos =
+  match Hashtbl.find_opt server.documents uri with
+  | None -> None
+  | Some _doc ->
+    (* Find definition of symbol at position *)
+    None
+
+(** Get references to symbol (not implemented: always [[]]) *)
+let get_references server uri _pos =
+  match Hashtbl.find_opt server.documents uri with
+  | None -> []
+  | Some _doc ->
+    (* Find all references to symbol at position *)
+    []
+
+(** Get document symbols.
+    One flat entry per top-level fn, type, trait, effect and const; both
+    ranges cover only the declared name, and later declarations come first. *)
+let get_document_symbols server uri =
+  match Hashtbl.find_opt server.documents uri with
+  | None -> []
+  | Some doc ->
+    match doc.doc_ast with
+    | None -> []
+    | Some prog ->
+      (* Convert AST to document symbols *)
+      let symbols = ref [] in
+      List.iter (fun decl ->
+        match decl with
+        | Ast.TopFn fd ->
+          let range = span_to_range fd.fd_name.span in
+          symbols := {
+            name = fd.fd_name.name;
+            kind = SKFunction;
+            range;
+            selection_range = range;
+            children = [];
+          } :: !symbols
+        | Ast.TopType td ->
+          let range = span_to_range td.td_name.span in
+          let kind = match td.td_body with
+            | Ast.TyEnum _ -> SKEnum
+            | Ast.TyStruct _ -> SKStruct
+            | Ast.TyAlias _ -> SKClass
+          in
+          symbols := {
+            name = td.td_name.name;
+            kind;
+            range;
+            selection_range = range;
+            children = [];
+          } :: !symbols
+        | Ast.TopTrait td ->
+          let range = span_to_range td.trd_name.span in
+          symbols := {
+            name = td.trd_name.name;
+            kind = SKInterface;
+            range;
+            selection_range = range;
+            children = [];
+          } :: !symbols
+        | Ast.TopEffect ed ->
+          let range = span_to_range ed.ed_name.span in
+          symbols := {
+            name = ed.ed_name.name;
+            kind = SKEvent;
+            range;
+            selection_range = range;
+            children = [];
+          } :: !symbols
+        | Ast.TopConst tc ->
+          let range = span_to_range tc.tc_name.span in
+          symbols := {
+            name = tc.tc_name.name;
+            kind = SKConstant;
+            range;
+            selection_range = range;
+            children = [];
+          } :: !symbols
+        | _ -> ()
+      ) prog.prog_decls;
+      !symbols
+
+(** Format document (not implemented: returns no edits) *)
+let format_document _server _uri =
+  (* Would implement pretty printer *)
+  []
+
+(** JSON-RPC message handling *)
+
+type json =
+  | JNull
+  | JBool of bool
+  | JInt of int
+  | JFloat of float
+  | JString of string
+  | JArray of json list
+  | JObject of (string * json) list
+
+(** Escape [s] for use inside a JSON string literal (RFC 8259, section 7).
+    Bytes >= 0x80 are copied through unchanged, so UTF-8 input stays UTF-8. *)
+let json_escape s =
+  let buf = Buffer.create (String.length s + 2) in
+  String.iter (function
+    | '"' -> Buffer.add_string buf "\\\""
+    | '\\' -> Buffer.add_string buf "\\\\"
+    | '\n' -> Buffer.add_string buf "\\n"
+    | '\r' -> Buffer.add_string buf "\\r"
+    | '\t' -> Buffer.add_string buf "\\t"
+    | '\b' -> Buffer.add_string buf "\\b"
+    | '\012' -> Buffer.add_string buf "\\f"
+    | c when Char.code c < 0x20 ->
+      Buffer.add_string buf (Printf.sprintf "\\u%04x" (Char.code c))
+    | c -> Buffer.add_char buf c
+  ) s;
+  Buffer.contents buf
+
+(** Render a float as a JSON number, using 15 significant digits when that
+    reads back exactly and 17 otherwise; NaN and infinities become [null]. *)
+let json_float f =
+  if not (Float.is_finite f) then "null"
+  else
+    let s = Printf.sprintf "%.15g" f in
+    if float_of_string s = f then s else Printf.sprintf "%.17g" f
+
+(** Serialize a [json] value to single-line JSON text *)
+let rec json_to_string = function
+  | JNull -> "null"
+  | JBool b -> if b then "true" else "false"
+  | JInt i -> string_of_int i
+  | JFloat f -> json_float f
+  | JString s -> "\"" ^ json_escape s ^ "\""
+  | JArray arr ->
+    "[" ^ String.concat ", " (List.map json_to_string arr) ^ "]"
+  | JObject obj ->
+    "{" ^ String.concat ", " (List.map (fun (k, v) ->
+      "\"" ^ json_escape k ^ "\": " ^ json_to_string v
+    ) obj) ^ "}"
+
+(** LSP message types *)
+type lsp_message =
+  | Request of int * string * json
+  | Notification of string * json
+  | Response of int * json option * json option
+
+(** Handle initialize request: marks the server initialized and returns the
+    advertised capabilities and server info.
+    NOTE(review): the capabilities are spelled out here rather than derived
+    from [server.capabilities] - keep the two in sync. *)
+let handle_initialize server _params =
+  server.initialized <- true;
+  JObject [
+    ("capabilities", JObject [
+      ("textDocumentSync", JInt 1);
+      ("hoverProvider", JBool true);
+      ("completionProvider", JObject [
+        ("triggerCharacters", JArray [JString "."; JString ":"]);
+      ]);
+      ("definitionProvider", JBool true);
+      ("referencesProvider", JBool true);
+      ("documentSymbolProvider", JBool true);
+      ("workspaceSymbolProvider", JBool true);
+      ("codeActionProvider", JBool true);
+      ("renameProvider", JBool true);
+      ("documentFormattingProvider", JBool true);
+    ]);
+    ("serverInfo", JObject [
+      ("name", JString "affinescript-lsp");
+      ("version", JString "0.1.0");
+    ]);
+  ]
+
+(** Handle shutdown request: records the shutdown and replies with null *)
+let handle_shutdown server =
+  server.shutdown <- true;
+  JNull
+
+(** Handle textDocument/didOpen *)
+let handle_did_open server _params =
+  (* Extract URI and content from params *)
+  let uri = "unknown" in (* Would parse from params *)
+  let content = "" in
+  let doc = {
+    doc_uri = uri;
+    doc_version = 1;
+    doc_content = content;
+    doc_ast = None;
+    doc_symbols = None;
+    doc_diagnostics = [];
+  } in
+  Hashtbl.replace server.documents uri doc;
+  analyze_document doc
+
+(** Handle
textDocument/didChange *) +let handle_did_change server params = + let uri = "unknown" in (* Would parse from params *) + match Hashtbl.find_opt server.documents uri with + | None -> () + | Some doc -> + (* Update content and reanalyze *) + analyze_document doc + +(** Handle textDocument/completion *) +let handle_completion server params = + let uri = "unknown" in + let pos = { line = 0; character = 0 } in + let items = get_completions server uri pos in + JObject [ + ("isIncomplete", JBool false); + ("items", JArray (List.map (fun item -> + JObject [ + ("label", JString item.label); + ("kind", JInt (match item.kind with CKFunction -> 3 | CKVariable -> 6 | CKKeyword -> 14 | _ -> 1)); + ] + ) items)); + ] + +(** Handle textDocument/hover *) +let handle_hover server params = + let uri = "unknown" in + let pos = { line = 0; character = 0 } in + match get_hover server uri pos with + | None -> JNull + | Some hover -> + JObject [ + ("contents", JObject [ + ("kind", JString "markdown"); + ("value", JString hover.contents); + ]); + ] + +(** Handle textDocument/documentSymbol *) +let handle_document_symbol server params = + let uri = "unknown" in + let symbols = get_document_symbols server uri in + JArray (List.map (fun sym -> + JObject [ + ("name", JString sym.name); + ("kind", JInt (match sym.kind with SKFunction -> 12 | SKClass -> 5 | SKEnum -> 10 | _ -> 1)); + ("range", JObject [ + ("start", JObject [("line", JInt sym.range.start.line); ("character", JInt sym.range.start.character)]); + ("end", JObject [("line", JInt sym.range.end_.line); ("character", JInt sym.range.end_.character)]); + ]); + ("selectionRange", JObject [ + ("start", JObject [("line", JInt sym.selection_range.start.line); ("character", JInt sym.selection_range.start.character)]); + ("end", JObject [("line", JInt sym.selection_range.end_.line); ("character", JInt sym.selection_range.end_.character)]); + ]); + ] + ) symbols) + +(** Main message dispatcher *) +let handle_message server msg = + match msg 
with + | Request (id, "initialize", params) -> + Some (Response (id, Some (handle_initialize server params), None)) + | Request (id, "shutdown", _) -> + Some (Response (id, Some (handle_shutdown server), None)) + | Request (id, "textDocument/completion", params) -> + Some (Response (id, Some (handle_completion server params), None)) + | Request (id, "textDocument/hover", params) -> + Some (Response (id, Some (handle_hover server params), None)) + | Request (id, "textDocument/documentSymbol", params) -> + Some (Response (id, Some (handle_document_symbol server params), None)) + | Notification ("initialized", _) -> None + | Notification ("textDocument/didOpen", params) -> + handle_did_open server params; + None + | Notification ("textDocument/didChange", params) -> + handle_did_change server params; + None + | Notification ("exit", _) -> + if server.shutdown then exit 0 else exit 1 + | _ -> None + +(** Run LSP server *) +let run () = + let server = create_server () in + (* Read messages from stdin, write to stdout *) + (* Simplified - would need proper JSON-RPC framing *) + Format.eprintf "AffineScript LSP server starting...@."; + let rec loop () = + (* Would read Content-Length header and JSON body *) + loop () + in + loop () diff --git a/lib/package.ml b/lib/package.ml new file mode 100644 index 0000000..3af9e42 --- /dev/null +++ b/lib/package.ml @@ -0,0 +1,397 @@ +(** Package manager for AffineScript *) + +(** Semantic version *) +type version = { + major: int; + minor: int; + patch: int; + prerelease: string option; +} +[@@deriving show] + +let parse_version s = + let parts = String.split_on_char '.' 
s in
+  match parts with
+  | [major; minor; patch] ->
+    (try
+       let patch, pre = match String.split_on_char '-' patch with
+         | [p] -> (p, None)
+         | [p; pre] -> (p, Some pre)
+         | _ -> (patch, None)
+       in
+       Some { major = int_of_string major;
+              minor = int_of_string minor;
+              patch = int_of_string patch;
+              prerelease = pre }
+     with _ -> None)
+  | [major; minor] ->
+    (try Some { major = int_of_string major;
+                minor = int_of_string minor;
+                patch = 0;
+                prerelease = None }
+     with _ -> None)
+  | _ -> None
+
+(** [string_of_version v] renders [v] as ["MAJOR.MINOR.PATCH"], followed by
+    ["-PRERELEASE"] when a pre-release tag is present. *)
+let string_of_version v =
+  let base = Printf.sprintf "%d.%d.%d" v.major v.minor v.patch in
+  match v.prerelease with
+  | Some pre -> base ^ "-" ^ pre
+  | None -> base
+
+(** [compare_prerelease p1 p2] orders two pre-release tags identifier by
+    identifier (identifiers are separated by ['.']):
+    - two all-digit identifiers compare numerically, so ["rc.2" < "rc.10"];
+    - an all-digit identifier sorts before an alphanumeric one;
+    - two alphanumeric identifiers compare as ASCII strings;
+    - when one tag is a prefix of the other, the shorter tag sorts first. *)
+let compare_prerelease p1 p2 =
+  let is_numeric s =
+    s <> "" && String.for_all (fun c -> c >= '0' && c <= '9') s
+  in
+  let compare_ident a b =
+    match is_numeric a, is_numeric b with
+    | true, true ->
+      (match int_of_string_opt a, int_of_string_opt b with
+       | Some x, Some y -> Int.compare x y
+       (* Digits that overflow [int]: fall back to a textual comparison. *)
+       | _ -> String.compare a b)
+    | true, false -> -1
+    | false, true -> 1
+    | false, false -> String.compare a b
+  in
+  let rec go xs ys =
+    match xs, ys with
+    | [], [] -> 0
+    | [], _ :: _ -> -1
+    | _ :: _, [] -> 1
+    | x :: xs', y :: ys' ->
+      let c = compare_ident x y in
+      if c <> 0 then c else go xs' ys'
+  in
+  go (String.split_on_char '.' p1) (String.split_on_char '.' p2)
+
+(** [compare_version v1 v2] is negative, zero or positive when [v1] is
+    respectively lower than, equal to or higher than [v2].  Major, minor and
+    patch are compared in that order; on a tie a release (no pre-release tag)
+    ranks above any pre-release, and two pre-releases are ordered by
+    {!compare_prerelease}. *)
+let compare_version v1 v2 =
+  let c = Int.compare v1.major v2.major in
+  if c <> 0 then c
+  else
+    let c = Int.compare v1.minor v2.minor in
+    if c <> 0 then c
+    else
+      let c = Int.compare v1.patch v2.patch in
+      if c <> 0 then c
+      else
+        match v1.prerelease, v2.prerelease with
+        | None, None -> 0
+        | None, Some _ -> 1  (* Release > prerelease *)
+        | Some _, None -> -1
+        | Some p1, Some p2 -> compare_prerelease p1 p2
+
+(** Version constraint *)
+type version_constraint =
+  | VExact of version  (** = 1.2.3 *)
+  | VGreater of version  (** > 1.2.3 *)
+  | VGreaterEq of version  (** >= 1.2.3 *)
+  | VLess of version  (** < 1.2.3 *)
+  | VLessEq of version  (** <= 1.2.3 *)
+  | VCaret of version  (** ^1.2.3 (compatible) *)
+  | VTilde of version  (** ~1.2.3 (patch-level) *)
+  | VAny  (** * *)
+  | VAnd of version_constraint * version_constraint
+  | VOr of version_constraint * version_constraint
+[@@deriving show]
+
+(** [satisfies_constraint v c] is [true] when version [v] is accepted by
+    constraint [c].  All comparisons go through {!compare_version}. *)
+let rec satisfies_constraint v = function
+  | VExact v2 -> compare_version v v2 = 0
+  | VGreater v2 -> compare_version v v2 > 0
+  | VGreaterEq v2 -> compare_version v v2 >= 0
+  | VLess v2 -> compare_version v v2 < 0
+  | VLessEq v2 -> compare_version v v2 <= 0
+  | VCaret v2 ->
+    (* ^1.2.3 means >=1.2.3, <2.0.0 (or <0.2.0 if major=0, <0.0.4 if minor=0) *)
+    compare_version v v2 >= 0 &&
+    (if v2.major > 0 then v.major = v2.major
+     else if v2.minor > 0 then v.major = 0 && v.minor = v2.minor
+     else v.major = 0 && v.minor = 0 && v.patch = v2.patch)
+  | VTilde v2 ->
+    (* ~1.2.3 means >=1.2.3, <1.3.0 *)
+    compare_version v v2 >= 0 &&
+    v.major = v2.major && v.minor = v2.minor
+  | VAny -> true
+  | VAnd (c1, c2) -> satisfies_constraint v c1 && satisfies_constraint v c2
+  | VOr (c1, c2) -> satisfies_constraint v c1 || satisfies_constraint v c2
+
+(** Dependency specification *)
+type dependency = {
+  dep_name: string;
+  dep_version: version_constraint;
+  dep_optional: bool;
+  dep_features: string list;
+}
+[@@deriving show]
+
+(** Package target *)
+type target_type =
+  | TargetLib
+  | TargetBin
+  | TargetTest
+  | TargetBench
+[@@deriving show]
+
+type target = {
+  tgt_name: string;
+  tgt_type: target_type;
+  tgt_path: string;
+  tgt_deps: string list;  (** Internal deps *)
+}
+[@@deriving show]
+
+(** Package manifest (afs.toml) *)
+type manifest = {
+  pkg_name: string;
+  pkg_version: version;
+  pkg_authors: string list;
+  pkg_license: string option;
+  pkg_description: string option;
+  pkg_repository: string option;
+  pkg_homepage: string option;
+  pkg_keywords: string list;
+  pkg_categories: string list;
+  pkg_edition: string;  (** AffineScript edition *)
+  pkg_dependencies: dependency list;
+  pkg_dev_dependencies: dependency list;
+  pkg_build_dependencies: dependency list;
+  pkg_features: (string * string list) list;  (** Feature flags *)
+  pkg_default_features: string list;
+  pkg_targets: target list;
+}
+[@@deriving show]
+
+let empty_manifest name = {
+  pkg_name = name;
+  pkg_version = { major = 0; minor = 1; patch = 0; prerelease = None };
+  pkg_authors = [];
+  pkg_license = None;
+  pkg_description = None;
+  pkg_repository = None;
+  pkg_homepage = None;
+  pkg_keywords = [];
+  pkg_categories = [];
+  pkg_edition = "2024";
+  pkg_dependencies = [];
+  pkg_dev_dependencies = [];
+  pkg_build_dependencies = [];
+  pkg_features = [];
+  pkg_default_features = [];
+  pkg_targets = [{ tgt_name = name; tgt_type =
TargetLib; tgt_path = "lib"; tgt_deps = [] }];
+}
+
+(** Lock file entry *)
+type lock_entry = {
+  le_name: string;
+  le_version: version;
+  le_source: string;  (** registry, git, path *)
+  le_checksum: string option;
+  le_dependencies: (string * version) list;
+}
+[@@deriving show]
+
+(** Lock file *)
+type lock_file = {
+  lf_version: int;
+  lf_entries: lock_entry list;
+}
+[@@deriving show]
+
+(** Package source *)
+type package_source =
+  | SourceRegistry of string  (** Official registry *)
+  | SourceGit of string * string  (** URL, ref *)
+  | SourcePath of string  (** Local path *)
+[@@deriving show]
+
+(** Resolved package *)
+type resolved_package = {
+  rp_manifest: manifest;
+  rp_source: package_source;
+  rp_path: string;  (** Local path after fetch *)
+}
+
+(** Package registry *)
+type registry = {
+  reg_url: string;
+  reg_name: string;
+}
+
+let default_registry = {
+  reg_url = "https://packages.affinescript.org";
+  reg_name = "afs";
+}
+
+(** Package manager state *)
+type pm_state = {
+  pm_root: string;  (** Project root *)
+  pm_cache_dir: string;  (** Global cache directory *)
+  pm_manifest: manifest option;
+  pm_lock: lock_file option;
+  pm_registries: registry list;
+}
+
+(** [init_state root] is a fresh state for the project rooted at [root], with
+    no manifest or lock file loaded.  The cache lives under [$HOME/.afs/cache],
+    falling back to [/tmp] when [HOME] is unset. *)
+let init_state root =
+  let home = Sys.getenv_opt "HOME" |> Option.value ~default:"/tmp" in
+  {
+    pm_root = root;
+    pm_cache_dir = Filename.concat home ".afs/cache";
+    pm_manifest = None;
+    pm_lock = None;
+    pm_registries = [default_registry];
+  }
+
+(** [string_of_version_constraint c] renders [c] in manifest syntax, e.g.
+    ["^1.2.3"], [">=1.0.0"], ["*"].  [VAnd] is written as ["a, b"] and [VOr]
+    as ["a || b"]; no grouping is emitted, so a [VOr] nested inside a [VAnd]
+    does not survive a round trip through {!version_constraint_of_string}. *)
+let rec string_of_version_constraint = function
+  | VExact v -> "=" ^ string_of_version v
+  | VGreater v -> ">" ^ string_of_version v
+  | VGreaterEq v -> ">=" ^ string_of_version v
+  | VLess v -> "<" ^ string_of_version v
+  | VLessEq v -> "<=" ^ string_of_version v
+  | VCaret v -> "^" ^ string_of_version v
+  | VTilde v -> "~" ^ string_of_version v
+  | VAny -> "*"
+  | VAnd (c1, c2) ->
+    string_of_version_constraint c1 ^ ", " ^ string_of_version_constraint c2
+  | VOr (c1, c2) ->
+    string_of_version_constraint c1 ^ " || " ^ string_of_version_constraint c2
+
+(** [version_constraint_of_string s] parses the syntax produced by
+    {!string_of_version_constraint}.  A bare version such as ["1.2.3"] is read
+    as a caret constraint.  Returns [None] when [s] is not a valid
+    constraint. *)
+let rec version_constraint_of_string s =
+  let s = String.trim s in
+  let pieces sep =
+    String.split_on_char sep s
+    |> List.map String.trim
+    |> List.filter (fun p -> p <> "")
+  in
+  (* Parse every piece and fold them together with [mk]; any failure wins. *)
+  let combine mk = function
+    | [] -> None
+    | first :: rest ->
+      List.fold_left
+        (fun acc piece ->
+           match acc, version_constraint_of_string piece with
+           | Some a, Some b -> Some (mk a b)
+           | _ -> None)
+        (version_constraint_of_string first) rest
+  in
+  let strip_prefix prefix =
+    let lp = String.length prefix and ls = String.length s in
+    if ls >= lp && String.sub s 0 lp = prefix
+    then Some (String.trim (String.sub s lp (ls - lp)))
+    else None
+  in
+  (* Two-character operators must be tried before their one-character prefix. *)
+  let operators = [
+    ">=", (fun v -> VGreaterEq v);
+    "<=", (fun v -> VLessEq v);
+    ">", (fun v -> VGreater v);
+    "<", (fun v -> VLess v);
+    "=", (fun v -> VExact v);
+    "^", (fun v -> VCaret v);
+    "~", (fun v -> VTilde v);
+  ] in
+  if String.contains s '|' then combine (fun a b -> VOr (a, b)) (pieces '|')
+  else if String.contains s ',' then combine (fun a b -> VAnd (a, b)) (pieces ',')
+  else if s = "*" then Some VAny
+  else
+    let prefixed =
+      List.find_map
+        (fun (op, mk) -> Option.map (fun rest -> (mk, rest)) (strip_prefix op))
+        operators
+    in
+    match prefixed with
+    | Some (mk, rest) -> Option.map mk (parse_version rest)
+    | None -> Option.map (fun v -> VCaret v) (parse_version s)
+
+(** Parse manifest from TOML (simplified).
+
+    Only [key = value] lines and [[section]] headers are understood.  The
+    recognised keys are [name], [version], [license], [description] and
+    [edition] in [[package]], plus every entry of [[dependencies]], where the
+    key is the package name and the value a version constraint.  An unreadable
+    dependency constraint falls back to [VAny].  Anything else is ignored.
+    Malformed lines never raise. *)
+let parse_manifest_string content =
+  (* Simplified TOML parsing - would use a real parser in production *)
+  let unquote value =
+    let n = String.length value in
+    if n >= 2 && value.[0] = '"' && value.[n - 1] = '"'
+    then String.sub value 1 (n - 2)
+    else value
+  in
+  let set_field manifest section key value =
+    match section, key with
+    | "package", "name" -> { manifest with pkg_name = value }
+    | "package", "version" ->
+      (match parse_version value with
+       | Some v -> { manifest with pkg_version = v }
+       | None -> manifest)
+    | "package", "license" -> { manifest with pkg_license = Some value }
+    | "package", "description" -> { manifest with pkg_description = Some value }
+    | "package", "edition" -> { manifest with pkg_edition = value }
+    | "dependencies", dep_name when dep_name <> "" ->
+      let dep_version =
+        Option.value (version_constraint_of_string value) ~default:VAny
+      in
+      let dep = { dep_name; dep_version; dep_optional = false; dep_features = [] } in
+      (* Accumulated in reverse; restored to file order at the end. *)
+      { manifest with pkg_dependencies = dep :: manifest.pkg_dependencies }
+    | _ -> manifest
+  in
+  let step (section, manifest) raw =
+    let line = String.trim raw in
+    let len = String.length line in
+    if len = 0 then (section, manifest)
+    else if line.[0] = '[' then begin
+      (* Tolerate a missing closing bracket instead of raising. *)
+      let stop = if line.[len - 1] = ']' then len - 1 else len in
+      (String.trim (String.sub line 1 (stop - 1)), manifest)
+    end else begin
+      (* Split on the first '=' only, so values may themselves contain '='. *)
+      match String.index_opt line '=' with
+      | None -> (section, manifest)
+      | Some eq ->
+        let key = String.trim (String.sub line 0 eq) in
+        let value =
+          unquote (String.trim (String.sub line (eq + 1) (len - eq - 1)))
+        in
+        (section, set_field manifest section key value)
+    end
+  in
+  let _, manifest =
+    List.fold_left step ("package", empty_manifest "unknown")
+      (String.split_on_char '\n' content)
+  in
+  { manifest with pkg_dependencies = List.rev manifest.pkg_dependencies }
+
+(** Read manifest from file.  Returns [None] when the file cannot be opened
+    or read; the channel is always closed. *)
+let read_manifest path =
+  match open_in path with
+  | exception Sys_error _ -> None
+  | ic ->
+    let content =
+      Fun.protect ~finally:(fun () -> close_in_noerr ic) (fun () ->
+        try Some (really_input_string ic (in_channel_length ic))
+        with Sys_error _ | End_of_file -> None)
+    in
+    Option.map parse_manifest_string content
+
+(** Write manifest to file, in the syntax {!parse_manifest_string} reads back.
+    @raise Sys_error if [path] cannot be opened or written. *)
+let write_manifest path manifest =
+  let oc = open_out path in
+  Fun.protect ~finally:(fun () -> close_out_noerr oc) (fun () ->
+    Printf.fprintf oc "[package]\n";
+    Printf.fprintf oc "name = \"%s\"\n" manifest.pkg_name;
+    Printf.fprintf oc "version = \"%s\"\n" (string_of_version manifest.pkg_version);
+    Option.iter (Printf.fprintf oc "license = \"%s\"\n") manifest.pkg_license;
+    Option.iter (Printf.fprintf oc "description = \"%s\"\n") manifest.pkg_description;
+    Printf.fprintf oc "edition = \"%s\"\n" manifest.pkg_edition;
+    if manifest.pkg_dependencies <> [] then begin
+      Printf.fprintf oc "\n[dependencies]\n";
+      List.iter (fun dep ->
+        Printf.fprintf oc "%s = \"%s\"\n" dep.dep_name
+          (string_of_version_constraint dep.dep_version)
+      ) manifest.pkg_dependencies
+    end;
+    (* Flush explicitly so a write failure is raised, not lost in cleanup. *)
+    flush oc)
+
+(** Dependency resolution *)
+
+type resolution_error =
+  | PackageNotFound of string
+  | VersionNotFound of string * version_constraint
+  | ConflictingVersions of string * version * version
+  | CyclicDependency of string list
+[@@deriving show]
+
+(** Build dependency graph *)
+type dep_graph = {
+  dg_nodes: (string, manifest) Hashtbl.t;
+  dg_edges: (string, string list) Hashtbl.t;
+}
+
+let create_dep_graph () = {
+  dg_nodes = Hashtbl.create 32;
+  dg_edges = Hashtbl.create 32;
+}
+
+(** Topological sort of dependencies.
+
+    Depth-first traversal over [graph.dg_edges], started from every key of
+    [graph.dg_nodes].  On success the result lists each package before the
+    packages it depends on.  The traversal stops at the first cycle and
+    returns [Error (CyclicDependency path)], where [path] walks the cycle and
+    starts and ends with the same package name. *)
+let topo_sort graph =
+  let visited = Hashtbl.create 32 in
+  let in_progress = Hashtbl.create 32 in
+  let result = ref [] in
+  (* [path] holds the packages currently being visited, most recent first. *)
+  let rec visit path name =
+    if Hashtbl.mem in_progress name then begin
+      let rec cycle acc = function
+        | [] -> acc
+        | n :: _ when n = name -> n :: acc
+        | n :: rest -> cycle (n :: acc) rest
+      in
+      Error (CyclicDependency (cycle [name] path))
+    end
+    else if Hashtbl.mem visited name then Ok ()
+    else begin
+      Hashtbl.replace in_progress name ();
+      let deps = Hashtbl.find_opt graph.dg_edges name |> Option.value ~default:[] in
+      match visit_all (name :: path) deps with
+      | Error e -> Error e
+      | Ok () ->
+        Hashtbl.remove in_progress name;
+        Hashtbl.replace visited name ();
+        result := name :: !result;
+        Ok ()
+    end
+  and visit_all path = function
+    | [] -> Ok ()
+    | name :: rest ->
+      (match visit path name with
+       | Ok () -> visit_all path rest
+       | Error e -> Error e)
+  in
+  let nodes = Hashtbl.fold (fun name _ acc -> name :: acc) graph.dg_nodes [] in
+  match visit_all [] nodes with
+  | Error e -> Error e
+  | Ok () -> Ok !result
+
+(** Package manager commands *)
+
+(** [mkdir_p dir] creates [dir] and any missing parent directories.
+    @raise Sys_error if a directory cannot be created. *)
+let rec mkdir_p dir =
+  if not (Sys.file_exists dir) then begin
+    let parent = Filename.dirname dir in
+    if parent <> dir then mkdir_p parent;
+    Sys.mkdir dir 0o755
+  end
+
+(** [cmd_init state name] scaffolds a new project under [state.pm_root]: the
+    [lib], [bin] and [test] directories, an [afs.toml] manifest and a starter
+    [lib/main.afs].  Directories are created through the OCaml runtime rather
+    than a shell, so paths containing spaces or shell metacharacters are safe.
+    @raise Sys_error if a directory or file cannot be created. *)
+let cmd_init state name =
+  (* Create directory structure first so the manifest has somewhere to go. *)
+  List.iter
+    (fun sub -> mkdir_p (Filename.concat state.pm_root sub))
+    ["lib"; "bin"; "test"];
+  let manifest = empty_manifest name in
+  let manifest_path = Filename.concat state.pm_root "afs.toml" in
+  write_manifest manifest_path manifest;
+  (* Create main.afs *)
+  let main_path = Filename.concat state.pm_root "lib/main.afs" in
+  let oc = open_out main_path in
+  Fun.protect ~finally:(fun () -> close_out_noerr oc) (fun () ->
+    Printf.fprintf oc "// %s - AffineScript project\n\n" name;
+    Printf.fprintf oc "pub fn hello() -> String {\n";
+    Printf.fprintf oc " \"Hello from %s!\"\n" name;
+    Printf.fprintf oc "}\n";
+    flush oc);
+  Ok ()
+
+(** [cmd_add state dep_name version_str] adds (or replaces) the dependency
+    [dep_name] in the loaded manifest and rewrites [afs.toml].
+
+    [version_str] may be ["*"], a bare version (treated as a caret
+    constraint), or a version prefixed with [^], [~], [=], [>], [>=], [<] or
+    [<=].  A version that does not parse falls back to [VAny]. *)
+let cmd_add state dep_name version_str =
+  match state.pm_manifest with
+  | None -> Error "No manifest found. Run 'afs init' first."
+  | Some manifest ->
+    let spec = String.trim version_str in
+    let has_prefix p =
+      String.length spec >= String.length p
+      && String.sub spec 0 (String.length p) = p
+    in
+    let after p mk =
+      let n = String.length p in
+      match parse_version (String.trim (String.sub spec n (String.length spec - n))) with
+      | Some v -> mk v
+      | None -> VAny
+    in
+    let version_constraint =
+      if spec = "*" then VAny
+      (* Two-character operators are tested before their one-character prefix. *)
+      else if has_prefix ">=" then after ">=" (fun v -> VGreaterEq v)
+      else if has_prefix "<=" then after "<=" (fun v -> VLessEq v)
+      else if has_prefix ">" then after ">" (fun v -> VGreater v)
+      else if has_prefix "<" then after "<" (fun v -> VLess v)
+      else if has_prefix "=" then after "=" (fun v -> VExact v)
+      else if has_prefix "~" then after "~" (fun v -> VTilde v)
+      else if has_prefix "^" then after "^" (fun v -> VCaret v)
+      else after "" (fun v -> VCaret v)  (* Default to caret *)
+    in
+    let dep = {
+      dep_name;
+      dep_version = version_constraint;
+      dep_optional = false;
+      dep_features = [];
+    } in
+    (* Drop any previous entry for this package so it is never listed twice. *)
+    let others =
+      List.filter (fun d -> d.dep_name <> dep_name) manifest.pkg_dependencies
+    in
+    let manifest' = { manifest with pkg_dependencies = dep :: others } in
+    let manifest_path = Filename.concat state.pm_root "afs.toml" in
+    write_manifest manifest_path manifest';
+    Ok ()
+
+(** [cmd_remove state dep_name] removes every dependency named [dep_name]
+    from the loaded manifest and rewrites [afs.toml]. *)
+let cmd_remove state dep_name =
+  match state.pm_manifest with
+  | None -> Error "No manifest found."
+  | Some manifest ->
+    let deps' = List.filter (fun d -> d.dep_name <> dep_name) manifest.pkg_dependencies in
+    let manifest' = { manifest with pkg_dependencies = deps' } in
+    let manifest_path = Filename.concat state.pm_root "afs.toml" in
+    write_manifest manifest_path manifest';
+    Ok ()
+
+let cmd_build _state =
+  (* Would invoke compiler on all targets *)
+  Ok ()
+
+let cmd_test _state =
+  (* Would run test targets *)
+  Ok ()
+
+let cmd_publish _state =
+  (* Would publish to registry *)
+  Ok ()
diff --git a/lib/repl.ml b/lib/repl.ml
new file mode 100644
index 0000000..78320d1
--- /dev/null
+++ b/lib/repl.ml
@@ -0,0 +1,407 @@
+(** Interactive REPL for AffineScript *)
+
+(** REPL state *)
+type state = {
+  env: Value.env;
+  mutable counter: int;
+  mutable debug: bool;
+  mutable multiline: bool;
+  mutable buffer: string;
+}
+
+let create_state () = {
+  env = Value.empty_env ();
+  counter = 0;
+  debug = false;
+  multiline = false;
+  buffer = "";
+}
+
+(** Built-in functions for the REPL environment *)
+let add_builtins env =
+  (* print: any -> () *)
+  Value.bind env "print"
(Value.VBuiltin ("print", fun args ->
+    List.iter (fun v -> Format.printf "%s@." (Value.show v)) args;
+    Value.VUnit
+  )) ~mutable_:false ~linear:false;
+
+  (* println: any -> () *)
+  Value.bind env "println" (Value.VBuiltin ("println", fun args ->
+    List.iter (fun v -> Format.printf "%s@." (Value.show v)) args;
+    Value.VUnit
+  )) ~mutable_:false ~linear:false;
+
+  (* str: any -> String *)
+  Value.bind env "str" (Value.VBuiltin ("str", fun args ->
+    match args with
+    | [v] -> Value.VString (Value.show v)
+    | _ -> failwith "str expects 1 argument"
+  )) ~mutable_:false ~linear:false;
+
+  (* len: array|string|tuple -> Int *)
+  Value.bind env "len" (Value.VBuiltin ("len", fun args ->
+    match args with
+    | [Value.VArray a] -> Value.VInt (Array.length a)
+    | [Value.VString s] -> Value.VInt (String.length s)
+    | [Value.VTuple t] -> Value.VInt (List.length t)
+    | [Value.VRecord r] -> Value.VInt (List.length r)
+    | _ -> failwith "len expects an array, string, tuple, or record"
+  )) ~mutable_:false ~linear:false;
+
+  (* type_of: any -> String *)
+  Value.bind env "type_of" (Value.VBuiltin ("type_of", fun args ->
+    match args with
+    | [v] ->
+      let ty = match v with
+        | Value.VUnit -> "Unit"
+        | Value.VBool _ -> "Bool"
+        | Value.VInt _ -> "Int"
+        | Value.VFloat _ -> "Float"
+        | Value.VChar _ -> "Char"
+        | Value.VString _ -> "String"
+        | Value.VTuple _ -> "Tuple"
+        | Value.VArray _ -> "Array"
+        | Value.VRecord _ -> "Record"
+        | Value.VVariant (ty, _, _) -> ty
+        | Value.VClosure _ -> "Function"
+        | Value.VBuiltin _ -> "Builtin"
+        | Value.VRef _ -> "Ref"
+      in
+      Value.VString ty
+    | _ -> failwith "type_of expects 1 argument"
+  )) ~mutable_:false ~linear:false;
+
+  (* range: Int -> Int -> Array[Int] *)
+  Value.bind env "range" (Value.VBuiltin ("range", fun args ->
+    match args with
+    | [Value.VInt start; Value.VInt stop] ->
+      Value.VArray (Array.init (max 0 (stop - start)) (fun i -> Value.VInt (start + i)))
+    | [Value.VInt stop] ->
+      Value.VArray (Array.init (max 0 stop) (fun i -> Value.VInt i))
+    | _ -> failwith "range expects 1 or 2 integer arguments"
+  )) ~mutable_:false ~linear:false;
+
+  (* push: Array[T] -> T -> Array[T] (returns new array) *)
+  Value.bind env "push" (Value.VBuiltin ("push", fun args ->
+    match args with
+    | [Value.VArray arr; elem] ->
+      Value.VArray (Array.append arr [|elem|])
+    | _ -> failwith "push expects an array and an element"
+  )) ~mutable_:false ~linear:false;
+
+  (* head: Array[T] -> T *)
+  Value.bind env "head" (Value.VBuiltin ("head", fun args ->
+    match args with
+    | [Value.VArray arr] when Array.length arr > 0 -> arr.(0)
+    | [Value.VArray _] -> failwith "head: empty array"
+    | _ -> failwith "head expects a non-empty array"
+  )) ~mutable_:false ~linear:false;
+
+  (* tail: Array[T] -> Array[T] *)
+  Value.bind env "tail" (Value.VBuiltin ("tail", fun args ->
+    match args with
+    | [Value.VArray arr] when Array.length arr > 0 ->
+      Value.VArray (Array.sub arr 1 (Array.length arr - 1))
+    | [Value.VArray _] -> failwith "tail: empty array"
+    | _ -> failwith "tail expects a non-empty array"
+  )) ~mutable_:false ~linear:false;
+
+  (* map: (T -> U) -> Array[T] -> Array[U] *)
+  Value.bind env "map" (Value.VBuiltin ("map", fun args ->
+    match args with
+    | [f; Value.VArray arr] ->
+      Value.VArray (Array.map (fun x -> Eval.apply f [x]) arr)
+    | _ -> failwith "map expects a function and an array"
+  )) ~mutable_:false ~linear:false;
+
+  (* filter: (T -> Bool) -> Array[T] -> Array[T] *)
+  Value.bind env "filter" (Value.VBuiltin ("filter", fun args ->
+    match args with
+    | [f; Value.VArray arr] ->
+      let filtered = Array.to_list arr |> List.filter (fun x ->
+        match Eval.apply f [x] with
+        | Value.VBool true -> true
+        | _ -> false
+      ) in
+      Value.VArray (Array.of_list filtered)
+    | _ -> failwith "filter expects a function and an array"
+  )) ~mutable_:false ~linear:false;
+
+  (* fold: (A -> T -> A) -> A -> Array[T] -> A *)
+  Value.bind env "fold" (Value.VBuiltin ("fold", fun args ->
+    match args with
+    | [f; init; Value.VArray arr] ->
+      Array.fold_left (fun acc x -> Eval.apply f [acc; x]) init arr
+    | _ -> failwith "fold expects a function, initial value, and an array"
+  )) ~mutable_:false ~linear:false;
+
+  (* assert: Bool -> () *)
+  Value.bind env "assert" (Value.VBuiltin ("assert", fun args ->
+    match args with
+    | [Value.VBool true] -> Value.VUnit
+    | [Value.VBool false] -> failwith "Assertion failed"
+    | _ -> failwith "assert expects a boolean"
+  )) ~mutable_:false ~linear:false;
+
+  (* panic: String -> Never *)
+  Value.bind env "panic" (Value.VBuiltin ("panic", fun args ->
+    match args with
+    | [Value.VString msg] -> failwith ("panic: " ^ msg)
+    | _ -> failwith "panic expects a string message"
+  )) ~mutable_:false ~linear:false;
+
+  ()
+
+(** REPL commands *)
+type command =
+  | CmdHelp
+  | CmdQuit
+  | CmdClear
+  | CmdEnv
+  | CmdDebug
+  | CmdLoad of string
+  | CmdType of string
+  | CmdAst of string
+  | CmdNone
+  | CmdInput of string
+
+(** [parse_command input] classifies one line of REPL input.
+
+    A blank line is [CmdNone].  A line starting with [':'] is split into a
+    command word and the remaining text, separated by any amount of
+    whitespace.  An unknown command word, or a no-argument command followed by
+    extra text, is handed back as [CmdInput] so it is evaluated as source.
+    Every other line is [CmdInput]. *)
+let parse_command input =
+  let input = String.trim input in
+  if String.length input = 0 then CmdNone
+  else if input.[0] = ':' then begin
+    let cmd = String.sub input 1 (String.length input - 1) in
+    (* Split at the first space only, so the argument text is kept intact. *)
+    let word, rest =
+      match String.index_opt cmd ' ' with
+      | None -> (cmd, "")
+      | Some i ->
+        (String.sub cmd 0 i,
+         String.trim (String.sub cmd (i + 1) (String.length cmd - i - 1)))
+    in
+    match word, rest with
+    | ("help" | "h" | "?"), "" -> CmdHelp
+    | ("quit" | "q" | "exit"), "" -> CmdQuit
+    | ("clear" | "c"), "" -> CmdClear
+    | ("env" | "e"), "" -> CmdEnv
+    | ("debug" | "d"), "" -> CmdDebug
+    | ("load" | "l"), file when file <> "" -> CmdLoad file
+    | ("type" | "t"), expr -> CmdType expr
+    | ("ast" | "a"), source -> CmdAst source
+    | _ -> CmdInput input  (* Unknown command, treat as expression *)
+  end
+  else CmdInput input
+
+(** Check if input is complete (balanced braces/parens).
+
+    Brackets inside double-quoted string literals (with backslash escapes) and
+    after a [//] line comment marker are not counted.  Surplus closing
+    brackets count as complete so the parser can report them. *)
+let is_complete input =
+  let n = String.length input in
+  let rec scan i depth =
+    if i >= n then depth
+    else
+      match input.[i] with
+      | '"' -> scan (skip_string (i + 1)) depth
+      | '/' when i + 1 < n && input.[i + 1] = '/' -> scan (skip_line (i + 2)) depth
+      | '(' | '[' | '{' -> scan (i + 1) (depth + 1)
+      | ')' | ']' | '}' -> scan (i + 1) (depth - 1)
+      | _ -> scan (i + 1) depth
+  (* Index just past the closing quote, or [n] if the string is unterminated. *)
+  and skip_string i =
+    if i >= n then n
+    else
+      match input.[i] with
+      | '\\' -> skip_string (i + 2)
+      | '"' -> i + 1
+      | _ -> skip_string (i + 1)
+  (* Index just past the end of the current line. *)
+  and skip_line i =
+    if i >= n then n
+    else if input.[i] = '\n' then i + 1
+    else skip_line (i + 1)
+  in
+  scan 0 0 <= 0
+
+(** Print REPL help.  The listed commands and aliases mirror
+    {!parse_command}. *)
+let print_help () =
+  Format.printf "AffineScript REPL Commands:@.";
+  Format.printf "  :help, :h, :?        Show this help@.";
+  Format.printf "  :quit, :q, :exit     Exit the REPL@.";
+  Format.printf "  :clear, :c           Clear the environment@.";
+  Format.printf "  :env, :e             Show bound variables@.";
+  Format.printf "  :debug, :d           Toggle debug mode@.";
+  Format.printf "  :load, :l <file>     Load and execute a file@.";
+  Format.printf "  :type, :t <expr>     Show the AST of an expression@.";
+  Format.printf "  :ast, :a <source>    Parse and show AST@.";
+  Format.printf "@.";
+  Format.printf "Examples:@.";
+  Format.printf "  > let x = 42@.";
+  Format.printf "  > x + 1@.";
+  Format.printf "  > fn add(a: Int, b: Int) -> Int { a + b }@.";
+  Format.printf "  > add(1, 2)@."
+
+(** Print environment *)
+let print_env env =
+  Format.printf "@[Bound variables:@.";
+  Hashtbl.iter (fun name binding ->
+    let prefix = if binding.Value.mutable_ then "mut " else "" in
+    Format.printf " %s%s = %s@." prefix name (Value.show binding.Value.value)
+  ) env.Value.bindings;
+  Format.printf "@]"
+
+(** [report_exn exn] prints a one-line diagnostic for [exn] on stderr.  Parse,
+    lexer and runtime errors and [Failure] get their usual REPL format; any
+    other exception is printed with [Printexc.to_string] so it is never lost
+    silently. *)
+let report_exn = function
+  | Parse_driver.Parse_error (msg, span) ->
+    Format.eprintf "%a: %s@." Span.pp_short span msg
+  | Lexer.Lexer_error (msg, pos) ->
+    Format.eprintf ":%d:%d: %s@." pos.Span.line pos.Span.col msg
+  | Eval.Runtime_error (msg, Some span) ->
+    Format.eprintf "%a: %s@." Span.pp_short span msg
+  | Eval.Runtime_error (msg, None) ->
+    Format.eprintf "Error: %s@." msg
+  | Failure msg ->
+    Format.eprintf "Error: %s@." msg
+  | exn ->
+    Format.eprintf "Error: %s@." (Printexc.to_string exn)
+
+(** Load and execute a file.  Errors are printed on stderr; the REPL state
+    keeps whatever definitions were evaluated before the failure. *)
+let load_file state filename =
+  try
+    let ic = open_in filename in
+    let source =
+      Fun.protect ~finally:(fun () -> close_in_noerr ic) (fun () ->
+        really_input_string ic (in_channel_length ic))
+    in
+    let prog = Parse_driver.parse_string ~file:filename source in
+    Eval.eval_program state.env prog;
+    Format.printf "Loaded %s@." filename
+  with
+  | Sys_error msg -> Format.eprintf "Error: %s@." msg
+  | End_of_file -> Format.eprintf "Error: %s: unexpected end of file@." filename
+  | Parse_driver.Parse_error (msg, span) ->
+    Format.eprintf "%a: parse error: %s@." Span.pp_short span msg
+  | Lexer.Lexer_error (msg, pos) ->
+    Format.eprintf "%s:%d:%d: lexer error: %s@." filename pos.Span.line pos.Span.col msg
+  | Eval.Runtime_error (msg, Some span) ->
+    Format.eprintf "%a: runtime error: %s@." Span.pp_short span msg
+  | Eval.Runtime_error (msg, None) ->
+    Format.eprintf "Runtime error: %s@." msg
+  (* Builtins signal failures (panic, assert, arity errors) with [Failure]. *)
+  | Failure msg ->
+    Format.eprintf "Error: %s@." msg
+
+(** Try to parse and evaluate as expression.
+
+    Returns [false] only when [input] does not parse as an expression.  Once it
+    parses, it is evaluated exactly once: on success the value is bound to a
+    fresh [_N] variable and printed, on failure the error is reported.  Either
+    way the result is [true], so the caller does not evaluate it again. *)
+let try_eval_expr state input =
+  match Parse_driver.parse_expr ~file:"" input with
+  | exception (Parse_driver.Parse_error _ | Lexer.Lexer_error _ | Failure _) -> false
+  | expr ->
+    if state.debug then
+      Format.printf "AST: %s@." (Ast.show_expr expr);
+    (try
+       let result = Eval.eval state.env expr in
+       state.counter <- state.counter + 1;
+       let var_name = Printf.sprintf "_%d" state.counter in
+       Value.bind state.env var_name result ~mutable_:false ~linear:false;
+       Format.printf "%s = %s@." var_name (Value.show result)
+     with exn -> report_exn exn);
+    true
+
+(** Try to parse and evaluate as statement/declaration.
+
+    Returns [false] only when [input] does not parse as a program; evaluation
+    errors are reported and still count as handled. *)
+let try_eval_decl state input =
+  match Parse_driver.parse_string ~file:"" input with
+  | exception (Parse_driver.Parse_error _ | Lexer.Lexer_error _ | Failure _) -> false
+  | prog ->
+    if state.debug then
+      Format.printf "AST: %s@." (Ast.show_program prog);
+    (try Eval.eval_program state.env prog
+     with exn -> report_exn exn);
+    true
+
+(** Evaluate REPL input *)
+let eval_input state input =
+  (* First try as expression *)
+  if try_eval_expr state input then ()
+  (* Then try as declaration/statement *)
+  else if try_eval_decl state input then ()
+  else
+    (* Neither form parses: re-parse as an expression to show its error. *)
+    try ignore (Parse_driver.parse_expr ~file:"" input)
+    with
+    | (Parse_driver.Parse_error _ | Lexer.Lexer_error _
+      | Eval.Runtime_error _ | Failure _) as exn -> report_exn exn
+
+(** Process a single line of input.  Returns [false] when the REPL should
+    exit, [true] otherwise. *)
+let process_line state line =
+  match parse_command line with
+  | CmdHelp -> print_help (); true
+  | CmdQuit -> false
+  | CmdClear ->
+    Hashtbl.clear state.env.Value.bindings;
+    add_builtins state.env;
+    state.counter <- 0;
+    (* Also drop any half-typed multi-line input. *)
+    state.buffer <- "";
+    state.multiline <- false;
+    Format.printf "Environment cleared.@.";
+    true
+  | CmdEnv -> print_env state.env; true
+  | CmdDebug ->
+    state.debug <- not state.debug;
+    Format.printf "Debug mode: %s@." (if state.debug then "on" else "off");
+    true
+  | CmdLoad file -> load_file state file; true
+  | CmdType input ->
+    (try
+       let expr = Parse_driver.parse_expr ~file:"" input in
+       Format.printf "%s@." (Ast.show_expr expr)
+     with
+     | (Parse_driver.Parse_error _ | Lexer.Lexer_error _) as exn -> report_exn exn);
+    true
+  | CmdAst input ->
+    (try
+       let prog = Parse_driver.parse_string ~file:"" input in
+       Format.printf "%s@." (Ast.show_program prog)
+     with
+     | (Parse_driver.Parse_error _ | Lexer.Lexer_error _) as exn -> report_exn exn);
+    true
+  | CmdNone -> true
+  | CmdInput input ->
+    (* Handle multiline input *)
+    if state.multiline then begin
+      state.buffer <- state.buffer ^ "\n" ^ input;
+      if is_complete state.buffer then begin
+        eval_input state state.buffer;
+        state.buffer <- "";
+        state.multiline <- false
+      end
+    end else if not (is_complete input) then begin
+      state.buffer <- input;
+      state.multiline <- true
+    end else
+      eval_input state input;
+    true
+
+(** [repl_loop state] reads lines from stdin and feeds them to
+    {!process_line} until it returns [false] or stdin reaches end of file. *)
+let repl_loop state =
+  let rec loop () =
+    let prompt = if state.multiline then "... " else ">>> " in
+    Format.printf "%s@?" prompt;
+    match In_channel.input_line In_channel.stdin with
+    | None -> Format.printf "@.Goodbye!@."
+    | Some line ->
+      if process_line state line then loop ()
+      else Format.printf "Goodbye!@."
+  in
+  loop ()
+
+(** Main REPL loop *)
+let run () =
+  Format.printf "@[";
+  Format.printf "AffineScript REPL v0.1.0@.";
+  Format.printf "Type :help for commands, :quit to exit@.";
+  Format.printf "@]@.";
+
+  let state = create_state () in
+  add_builtins state.env;
+  repl_loop state
+
+(** Run REPL with initial file *)
+let run_with_file filename =
+  let state = create_state () in
+  add_builtins state.env;
+  load_file state filename;
+  Format.printf "@[";
+  Format.printf "AffineScript REPL v0.1.0 (loaded %s)@." filename;
+  Format.printf "Type :help for commands, :quit to exit@.";
+  Format.printf "@]@.";
+  repl_loop state
+
+(** Evaluate a string and return the result (for testing).
+
+    [input] is parsed as an expression first.  Only when that parse fails is
+    it parsed and run as a program, which yields [VUnit].  Errors raised while
+    evaluating are returned as [Error] and the input is never evaluated a
+    second time. *)
+let eval_string ?(env = Value.empty_env ()) input =
+  add_builtins env;
+  let guarded f =
+    try f () with
+    | Parse_driver.Parse_error (msg, _) -> Error ("Parse error: " ^ msg)
+    | Lexer.Lexer_error (msg, _) -> Error ("Lexer error: " ^ msg)
+    | Eval.Runtime_error (msg, _) -> Error ("Runtime error: " ^ msg)
+    | Failure msg -> Error msg
+    | exn -> Error (Printexc.to_string exn)
+  in
+  (* Try as expression first *)
+  match Parse_driver.parse_expr ~file:"" input with
+  | expr -> guarded (fun () -> Ok (Eval.eval env expr))
+  | exception (Parse_driver.Parse_error _ | Lexer.Lexer_error _ | Failure _) ->
+    (* Try as program *)
+    guarded (fun () ->
+      let prog = Parse_driver.parse_string ~file:"" input in
+      Eval.eval_program env prog;
+      Ok Value.VUnit)
diff --git a/lib/resolve.ml b/lib/resolve.ml
new file mode 100644
index 0000000..be2e5f9
--- /dev/null
+++ b/lib/resolve.ml
@@ -0,0 +1,634 @@
+(** Name resolution pass for AffineScript *)
+
+open Ast
+open Symbol
+
+(** Resolution errors *)
+type error =
+  | UnboundVariable of string * Span.t
+  | UnboundType of string * Span.t
+  | UnboundModule of string list * Span.t
+  | DuplicateDefinition of string * Span.t * Span.t option
+  | InvalidBreak of Span.t
+  | InvalidContinue of Span.t
+  | InvalidReturn of Span.t
+  | PrivateAccess of string * Span.t
+  | CyclicDependency of string list
+
+exception Resolution_error of error
+
+let error_to_string = function
+  | UnboundVariable (name, _) ->
+    Printf.sprintf "Unbound variable: %s" name
+  | UnboundType (name, _) ->
+    Printf.sprintf "Unbound type: %s" name
+  | UnboundModule (path, _) ->
+    Printf.sprintf "Unbound module: %s" (String.concat "."
path)
+  | DuplicateDefinition (name, _, _) ->
+    Printf.sprintf "Duplicate definition: %s" name
+  | InvalidBreak _ -> "break outside of loop"
+  | InvalidContinue _ -> "continue outside of loop"
+  | InvalidReturn _ -> "return outside of function"
+  | PrivateAccess (name, _) ->
+    Printf.sprintf "Cannot access private symbol: %s" name
+  | CyclicDependency path ->
+    Printf.sprintf "Cyclic dependency: %s" (String.concat " -> " path)
+
+(** Resolved identifier - original ident with symbol reference *)
+type resolved_ident = {
+  ri_name: string;
+  ri_span: Span.t;
+  ri_symbol: symbol;
+}
+[@@deriving show]
+
+(** Resolution context *)
+type context = {
+  ctx_symtab: symbol_table;
+  ctx_errors: error list ref;
+  ctx_current_module: module_path;
+  ctx_type_params: (string, symbol) Hashtbl.t;  (** In-scope type parameters *)
+}
+
+let create_context () = {
+  ctx_symtab = Symbol.create ();
+  ctx_errors = ref [];
+  ctx_current_module = [];
+  ctx_type_params = Hashtbl.create 8;
+}
+
+(** [add_error ctx err] records [err]; errors accumulate most recent first. *)
+let add_error ctx err =
+  ctx.ctx_errors := err :: !(ctx.ctx_errors)
+
+(** Convert AST visibility to Symbol visibility *)
+let convert_visibility = function
+  | Ast.Private -> VisPrivate
+  | Ast.Public -> VisPublic
+  | Ast.PubCrate -> VisPubCrate
+  | Ast.PubSuper -> VisPubSuper
+  | Ast.PubIn path -> VisPubIn (List.map (fun id -> id.name) path)
+
+(** Resolve a variable reference.  Records [UnboundVariable] and returns
+    [None] when the name is not in scope. *)
+let resolve_var ctx (id: ident) =
+  match Symbol.lookup ctx.ctx_symtab id.name with
+  | Some sym -> Some { ri_name = id.name; ri_span = id.span; ri_symbol = sym }
+  | None ->
+    add_error ctx (UnboundVariable (id.name, id.span));
+    None
+
+(** Resolve a type reference.  Type parameters take precedence over the
+    regular scope; a name bound to something that is not a type or type
+    parameter is reported as [UnboundType]. *)
+let resolve_type_ref ctx (id: ident) =
+  (* First check type parameters *)
+  match Hashtbl.find_opt ctx.ctx_type_params id.name with
+  | Some sym -> Some sym
+  | None ->
+    (* Then check regular scope *)
+    match Symbol.lookup ctx.ctx_symtab id.name with
+    | Some sym when sym.sym_kind = SKType || sym.sym_kind = SKTypeParam ->
+      Some sym
+    | Some _ ->
+      add_error ctx (UnboundType (id.name, id.span));
+      None
+    | None ->
+      add_error ctx (UnboundType (id.name, id.span));
+      None
+
+(** [report_duplicate ctx ~name ~span] records a [DuplicateDefinition] when
+    [name] is already defined in the current scope, carrying the span of the
+    earlier definition when it is known. *)
+let report_duplicate ctx ~name ~span =
+  if Symbol.is_defined_local ctx.ctx_symtab name then begin
+    let existing = Symbol.find_local (Symbol.current_scope ctx.ctx_symtab) name in
+    let prev_span = Option.bind existing (fun s -> s.sym_span) in
+    add_error ctx (DuplicateDefinition (name, span, prev_span))
+  end
+
+(** Define a variable in current scope *)
+let define_var ctx ~name ~span ~mutable_ ~ty ~quantity =
+  report_duplicate ctx ~name ~span;
+  let sym = Symbol.make_symbol
+    ~name
+    ~kind:SKVariable
+    ~span
+    ~mutable_
+    ?ty
+    ?quantity
+    ()
+  in
+  Symbol.register ctx.ctx_symtab sym
+
+(** Define a function *)
+let define_function ctx ~name ~span ~vis ~ty =
+  report_duplicate ctx ~name ~span;
+  let sym = Symbol.make_symbol
+    ~name
+    ~kind:SKFunction
+    ~span
+    ~vis:(convert_visibility vis)
+    ?ty
+    ()
+  in
+  Symbol.register ctx.ctx_symtab sym
+
+(** Define a type *)
+let define_type ctx ~name ~span ~vis =
+  report_duplicate ctx ~name ~span;
+  let sym = Symbol.make_symbol
+    ~name
+    ~kind:SKType
+    ~span
+    ~vis:(convert_visibility vis)
+    ()
+  in
+  Symbol.register ctx.ctx_symtab sym
+
+(** Define a type parameter *)
+let define_type_param ctx (tp: type_param) =
+  let sym = Symbol.make_symbol
+    ~name:tp.tp_name.name
+    ~kind:SKTypeParam
+    ~span:tp.tp_name.span
+    ?quantity:tp.tp_quantity
+    ()
+  in
+  Hashtbl.replace ctx.ctx_type_params tp.tp_name.name sym;
+  Symbol.register ctx.ctx_symtab sym
+
+(** Clear type parameters after leaving scope.  This empties the whole table,
+    including entries added by any enclosing declaration. *)
+let clear_type_params ctx =
+  Hashtbl.clear ctx.ctx_type_params
+
+(** Enter scope helper *)
+let enter_scope ctx kind =
+  ignore (Symbol.enter_scope ctx.ctx_symtab kind)
+
+(** Exit scope helper *)
+let exit_scope ctx =
+  Symbol.exit_scope ctx.ctx_symtab
+
+(** [pattern_vars acc pat] prepends to [acc] the names bound by [PatVar] and
+    [PatAs] nodes anywhere in [pat]. *)
+let rec pattern_vars acc pat =
+  match pat with
+  | PatWildcard _ | PatLit _ -> acc
+  | PatVar id -> id.name :: acc
+  | PatCon (_, pats) | PatTuple pats -> List.fold_left pattern_vars acc pats
+  | PatRecord (fields, _) ->
+    List.fold_left (fun acc (_, pat_opt) ->
+      match pat_opt with
+      | Some p -> pattern_vars acc p
+      | None -> acc
+    ) acc fields
+  | PatOr (p1, p2) -> pattern_vars (pattern_vars acc p1) p2
+  | PatAs (id, p) -> pattern_vars (id.name :: acc) p
+
+(** [resolve_pattern_skipping ctx already pat] resolves [pat] like
+    {!resolve_pattern}, except that variables whose name is in [already] are
+    not defined again.  It is used for the right-hand side of an or-pattern,
+    whose variables were already introduced by the left-hand side. *)
+let rec resolve_pattern_skipping ctx already pat =
+  let bind (id: ident) =
+    if not (List.mem id.name already) then
+      ignore (define_var ctx
+        ~name:id.name
+        ~span:id.span
+        ~mutable_:false
+        ~ty:None
+        ~quantity:None)
+  in
+  match pat with
+  | PatWildcard _ -> ()
+  | PatVar id -> bind id
+  | PatLit _ -> ()
+  | PatCon (id, pats) ->
+    (* Resolve constructor reference *)
+    (match Symbol.lookup ctx.ctx_symtab id.name with
+     | Some sym when sym.sym_kind = SKVariant -> ()
+     | _ -> add_error ctx (UnboundVariable (id.name, id.span)));
+    List.iter (resolve_pattern_skipping ctx already) pats
+  | PatTuple pats ->
+    List.iter (resolve_pattern_skipping ctx already) pats
+  | PatRecord (fields, _) ->
+    (* Fields without a sub-pattern define nothing here. *)
+    List.iter (fun (_, pat_opt) ->
+      Option.iter (resolve_pattern_skipping ctx already) pat_opt
+    ) fields
+  | PatOr (p1, p2) ->
+    (* Both branches should bind same names: the names introduced by [p1]
+       are skipped in [p2] so they are not reported as duplicates. *)
+    resolve_pattern_skipping ctx already p1;
+    resolve_pattern_skipping ctx (pattern_vars already p1) p2
+  | PatAs (id, p) ->
+    bind id;
+    resolve_pattern_skipping ctx already p
+
+(** Resolve a pattern, defining bound variables *)
+let resolve_pattern ctx pat =
+  resolve_pattern_skipping ctx [] pat
+
+(** Resolve a type expression *)
+let rec resolve_type_expr ctx ty =
+  match ty with
+  | TyVar id | TyCon id ->
+    ignore (resolve_type_ref ctx id)
+  | TyApp (id, args) ->
+    ignore (resolve_type_ref ctx id);
+    List.iter (resolve_type_arg ctx) args
+  | TyArrow (t1, t2, eff) ->
+    resolve_type_expr ctx t1;
+    resolve_type_expr ctx t2;
+    Option.iter (resolve_effect_expr ctx) eff
+  | TyDepArrow { da_param; da_param_ty; da_ret_ty; da_eff; _ } ->
+    resolve_type_expr ctx da_param_ty;
+    enter_scope ctx ScopeBlock;
+    ignore (define_var ctx
+      ~name:da_param.name
+      ~span:da_param.span
+      ~mutable_:false
+      ~ty:(Some da_param_ty)
+      ~quantity:None);
+    resolve_type_expr ctx da_ret_ty;
+    Option.iter (resolve_effect_expr ctx) da_eff;
+    exit_scope ctx
+  | TyTuple tys ->
+
List.iter (resolve_type_expr ctx) tys + | TyRecord (fields, rest) -> + List.iter (fun rf -> resolve_type_expr ctx rf.rf_ty) fields; + Option.iter (fun id -> ignore (resolve_type_ref ctx id)) rest + | TyOwn t | TyRef t | TyMut t -> + resolve_type_expr ctx t + | TyRefined (t, _pred) -> + resolve_type_expr ctx t + (* TODO: resolve predicate names *) + | TyHole -> () + +and resolve_type_arg ctx = function + | TyArg t -> resolve_type_expr ctx t + | NatArg _ -> () (* TODO: resolve nat expressions *) + +and resolve_effect_expr ctx = function + | EffVar id -> ignore (resolve_type_ref ctx id) + | EffCon (id, args) -> + ignore (resolve_type_ref ctx id); + List.iter (resolve_type_arg ctx) args + | EffUnion (e1, e2) -> + resolve_effect_expr ctx e1; + resolve_effect_expr ctx e2 + +(** Resolve an expression *) +let rec resolve_expr ctx expr = + match expr with + | ExprSpan (e, _) -> resolve_expr ctx e + | ExprLit _ -> () + | ExprVar id -> + ignore (resolve_var ctx id) + | ExprLet { el_mut; el_pat; el_ty; el_value; el_body } -> + resolve_expr ctx el_value; + Option.iter (resolve_type_expr ctx) el_ty; + resolve_pattern ctx el_pat; + Option.iter (resolve_expr ctx) el_body + | ExprIf { ei_cond; ei_then; ei_else } -> + resolve_expr ctx ei_cond; + resolve_expr ctx ei_then; + Option.iter (resolve_expr ctx) ei_else + | ExprMatch { em_scrutinee; em_arms } -> + resolve_expr ctx em_scrutinee; + List.iter (resolve_match_arm ctx) em_arms + | ExprLambda { elam_params; elam_ret_ty; elam_body } -> + enter_scope ctx ScopeFunction; + List.iter (resolve_param ctx) elam_params; + Option.iter (resolve_type_expr ctx) elam_ret_ty; + resolve_expr ctx elam_body; + exit_scope ctx + | ExprApp (func, args) -> + resolve_expr ctx func; + List.iter (resolve_expr ctx) args + | ExprField (e, _) -> + resolve_expr ctx e + | ExprTupleIndex (e, _) -> + resolve_expr ctx e + | ExprIndex (e1, e2) -> + resolve_expr ctx e1; + resolve_expr ctx e2 + | ExprTuple exprs | ExprArray exprs -> + List.iter (resolve_expr ctx) 
exprs + | ExprRecord { er_fields; er_spread } -> + List.iter (fun (_, expr_opt) -> + Option.iter (resolve_expr ctx) expr_opt + ) er_fields; + Option.iter (resolve_expr ctx) er_spread + | ExprRowRestrict (e, _) -> + resolve_expr ctx e + | ExprBinary (e1, _, e2) -> + resolve_expr ctx e1; + resolve_expr ctx e2 + | ExprUnary (_, e) -> + resolve_expr ctx e + | ExprBlock blk -> + resolve_block ctx blk + | ExprReturn e_opt -> + if not (Symbol.in_function (Symbol.current_scope ctx.ctx_symtab)) then + add_error ctx (InvalidReturn (Span.dummy)); (* TODO: get span *) + Option.iter (resolve_expr ctx) e_opt + | ExprTry { et_body; et_catch; et_finally } -> + resolve_block ctx et_body; + Option.iter (List.iter (resolve_match_arm ctx)) et_catch; + Option.iter (resolve_block ctx) et_finally + | ExprHandle { eh_body; eh_handlers } -> + resolve_expr ctx eh_body; + List.iter (resolve_handler_arm ctx) eh_handlers + | ExprResume e_opt -> + Option.iter (resolve_expr ctx) e_opt + | ExprUnsafe ops -> + List.iter (resolve_unsafe_op ctx) ops + | ExprVariant (type_id, _variant_id) -> + ignore (resolve_type_ref ctx type_id) + +and resolve_match_arm ctx arm = + enter_scope ctx ScopeMatch; + resolve_pattern ctx arm.ma_pat; + Option.iter (resolve_expr ctx) arm.ma_guard; + resolve_expr ctx arm.ma_body; + exit_scope ctx + +and resolve_handler_arm ctx = function + | HandlerReturn (pat, e) -> + enter_scope ctx ScopeMatch; + resolve_pattern ctx pat; + resolve_expr ctx e; + exit_scope ctx + | HandlerOp (_, pats, e) -> + enter_scope ctx ScopeMatch; + List.iter (resolve_pattern ctx) pats; + resolve_expr ctx e; + exit_scope ctx + +and resolve_unsafe_op ctx = function + | UnsafeRead e | UnsafeForget e -> resolve_expr ctx e + | UnsafeWrite (e1, e2) | UnsafeOffset (e1, e2) -> + resolve_expr ctx e1; + resolve_expr ctx e2 + | UnsafeTransmute (t1, t2, e) -> + resolve_type_expr ctx t1; + resolve_type_expr ctx t2; + resolve_expr ctx e + | UnsafeAssume _ -> () + +and resolve_block ctx { blk_stmts; blk_expr } = + 
(** Resolve a type declaration.

    Defines the type through [define_type], then resolves the declaration body
    inside a temporary [ScopeBlock] scope in which the type parameters are
    defined.  Returns the value produced by [define_type].

    Enum variant constructors are recorded in the symbol table and made
    visible from the global scope: the temporary scope is left before this
    function returns, so registering them there would lose them. *)
let resolve_type_decl ctx (decl: type_decl) =
  (* Define the type *)
  let type_sym = define_type ctx
    ~name:decl.td_name.name
    ~span:decl.td_name.span
    ~vis:decl.td_vis
  in
  (* Enter scope for type parameters *)
  enter_scope ctx ScopeBlock;
  List.iter (define_type_param ctx) decl.td_type_params;
  (* Resolve type body *)
  (match decl.td_body with
   | TyAlias ty ->
     resolve_type_expr ctx ty
   | TyStruct fields ->
     List.iter (fun sf -> resolve_type_expr ctx sf.sf_ty) fields
   | TyEnum variants ->
     List.iter (fun vd ->
       (* Define variant constructor *)
       let variant_sym = Symbol.make_symbol
         ~name:vd.vd_name.name
         ~kind:SKVariant
         ~span:vd.vd_name.span
         ()
       in
       (* Register at global scope for constructor access.  The symbol was
          previously built and then dropped, so constructors were never
          reachable by name. *)
       let symtab = ctx.ctx_symtab in
       Hashtbl.replace symtab.Symbol.st_symbols
         variant_sym.Symbol.sym_id variant_sym;
       let global = Symbol.global_scope symtab in
       (* Do not clobber an existing global binding of the same name: scopes
          map one name to one symbol.
          NOTE(review): confirm whether a clash should be reported as an
          error instead of keeping the earlier binding. *)
       if Option.is_none
            (Symbol.find_local global variant_sym.Symbol.sym_name)
       then Symbol.add_symbol global variant_sym;
       List.iter (resolve_type_expr ctx) vd.vd_fields;
       Option.iter (resolve_type_expr ctx) vd.vd_ret_ty
     ) variants);
  exit_scope ctx;
  clear_type_params ctx;
  type_sym
(** Resolve a top-level declaration.

    Dispatches on the declaration kind.  Symbols returned by the per-kind
    resolvers are discarded.  For a constant, the type annotation and then the
    initialiser are resolved before the constant is defined as an immutable
    variable with no quantity annotation. *)
let resolve_top_level ctx = function
  | TopFn decl -> ignore (resolve_fn_decl ctx decl)
  | TopType decl -> ignore (resolve_type_decl ctx decl)
  | TopEffect decl -> ignore (resolve_effect_decl ctx decl)
  | TopTrait decl -> ignore (resolve_trait_decl ctx decl)
  | TopImpl impl -> resolve_impl_block ctx impl
  (* The visibility is not consumed below, so it is matched with a wildcard
     instead of being bound to an unused variable. *)
  | TopConst { tc_vis = _; tc_name; tc_ty; tc_value } ->
    resolve_type_expr ctx tc_ty;
    resolve_expr ctx tc_value;
    ignore (define_var ctx
      ~name:tc_name.name
      ~span:tc_name.span
      ~mutable_:false
      ~ty:(Some tc_ty)
      ~quantity:None)
(** Resolve a complete program.

    Works in two passes over [prog.prog_decls] on a fresh context seeded with
    the built-ins:
    - the first pass only declares top-level names (functions, types,
      effects, traits and constants; impl blocks declare nothing), which
      allows forward references;
    - the second pass resolves the imports and then every declaration.

    Returns the symbol table paired with the recorded errors, the stored
    error list being reversed before it is returned. *)
let resolve_program (prog: program) =
  let ctx = create_context () in
  add_builtins ctx;
  (* Build a symbol carrying only a name, a kind and a span, and register it
     in the current scope. *)
  let declare_bare ~kind ~name ~span =
    let sym = Symbol.make_symbol ~name ~kind ~span () in
    ignore (Symbol.register ctx.ctx_symtab sym)
  in
  (* First pass: collect one top-level definition. *)
  let predeclare = function
    | TopFn fd ->
      ignore (define_function ctx
        ~name:fd.fd_name.name
        ~span:fd.fd_name.span
        ~vis:fd.fd_vis
        ~ty:None)
    | TopType td ->
      ignore (define_type ctx
        ~name:td.td_name.name
        ~span:td.td_name.span
        ~vis:td.td_vis)
    | TopEffect ed ->
      declare_bare ~kind:SKEffect ~name:ed.ed_name.name ~span:ed.ed_name.span
    | TopTrait td ->
      declare_bare ~kind:SKTrait ~name:td.trd_name.name ~span:td.trd_name.span
    | TopImpl _ -> ()
    | TopConst { tc_name; _ } ->
      ignore (define_var ctx
        ~name:tc_name.name
        ~span:tc_name.span
        ~mutable_:false
        ~ty:None
        ~quantity:None)
  in
  List.iter predeclare prog.prog_decls;
  (* Second pass: resolve all definitions *)
  List.iter (resolve_import ctx) prog.prog_imports;
  List.iter (resolve_top_level ctx) prog.prog_decls;
  (* Return context with resolved symbols and any errors *)
  let errors = List.rev !(ctx.ctx_errors) in
  (ctx.ctx_symtab, errors)
gcd(y, x % y) } +} + +// LCM +fn lcm(a: Int, b: Int) -> Int { + if a == 0 || b == 0 { 0 } + else { abs(a * b) / gcd(a, b) } +} + +// Factorial +fn factorial(n: Int) -> Int { + if n <= 1 { 1 } + else { n * factorial(n - 1) } +} + +// Fibonacci +fn fib(n: Int) -> Int { + if n <= 1 { n } + else { fib(n - 1) + fib(n - 2) } +} + +// Is prime check +fn is_prime(n: Int) -> Bool { + if n < 2 { false } + else if n == 2 { true } + else if n % 2 == 0 { false } + else { + let mut i = 3; + let mut result = true; + while i * i <= n && result { + if n % i == 0 { result = false; } + i += 2; + } + result + } +} +|} + +(** String module *) +let string_source = {| +module Std.String; + +// Check if string is empty +fn is_empty(s: String) -> Bool { + len(s) == 0 +} + +// Repeat a string n times +fn repeat(s: String, n: Int) -> String { + if n <= 0 { "" } + else if n == 1 { s } + else { + let mut result = ""; + let mut i = 0; + while i < n { + result = result + s; + i += 1; + } + result + } +} + +// Check if string starts with prefix +fn starts_with(s: String, prefix: String) -> Bool { + let slen = len(s); + let plen = len(prefix); + if plen > slen { false } + else { + let mut i = 0; + let mut result = true; + while i < plen && result { + if s[i] != prefix[i] { result = false; } + i += 1; + } + result + } +} + +// Check if string ends with suffix +fn ends_with(s: String, suffix: String) -> Bool { + let slen = len(s); + let suflen = len(suffix); + if suflen > slen { false } + else { + let mut i = 0; + let offset = slen - suflen; + let mut result = true; + while i < suflen && result { + if s[offset + i] != suffix[i] { result = false; } + i += 1; + } + result + } +} + +// Reverse a string +fn reverse(s: String) -> String { + let n = len(s); + if n <= 1 { s } + else { + let mut result = ""; + let mut i = n - 1; + while i >= 0 { + result = result + str(s[i]); + i -= 1; + } + result + } +} + +// Check if character is digit +fn is_digit(c: Char) -> Bool { + c >= '0' && c <= '9' +} + +// Check 
if character is letter +fn is_alpha(c: Char) -> Bool { + (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') +} + +// Check if character is alphanumeric +fn is_alnum(c: Char) -> Bool { + is_digit(c) || is_alpha(c) +} + +// Check if character is whitespace +fn is_space(c: Char) -> Bool { + c == ' ' || c == '\t' || c == '\n' || c == '\r' +} + +// Convert character to uppercase +fn to_upper(c: Char) -> Char { + if c >= 'a' && c <= 'z' { + // ASCII math: 'a' - 'A' = 32 + c - 32 + } else { c } +} + +// Convert character to lowercase +fn to_lower(c: Char) -> Char { + if c >= 'A' && c <= 'Z' { + c + 32 + } else { c } +} +|} + +(** Array/List module *) +let array_source = {| +module Std.Array; + +// Check if array is empty +fn is_empty[T](arr: Array[T]) -> Bool { + len(arr) == 0 +} + +// Get first element +fn first[T](arr: Array[T]) -> Option[T] { + if len(arr) == 0 { None } + else { Some(arr[0]) } +} + +// Get last element +fn last[T](arr: Array[T]) -> Option[T] { + let n = len(arr); + if n == 0 { None } + else { Some(arr[n - 1]) } +} + +// Sum of integer array +fn sum(arr: Array[Int]) -> Int { + fold(\acc, x -> acc + x, 0, arr) +} + +// Product of integer array +fn product(arr: Array[Int]) -> Int { + fold(\acc, x -> acc * x, 1, arr) +} + +// Find minimum +fn min(arr: Array[Int]) -> Option[Int] { + if len(arr) == 0 { None } + else { + let mut result = arr[0]; + let mut i = 1; + while i < len(arr) { + if arr[i] < result { result = arr[i]; } + i += 1; + } + Some(result) + } +} + +// Find maximum +fn max(arr: Array[Int]) -> Option[Int] { + if len(arr) == 0 { None } + else { + let mut result = arr[0]; + let mut i = 1; + while i < len(arr) { + if arr[i] > result { result = arr[i]; } + i += 1; + } + Some(result) + } +} + +// Check if any element satisfies predicate +fn any[T](pred: fn(T) -> Bool, arr: Array[T]) -> Bool { + let mut i = 0; + let mut found = false; + while i < len(arr) && !found { + if pred(arr[i]) { found = true; } + i += 1; + } + found +} + +// Check if all 
elements satisfy predicate +fn all[T](pred: fn(T) -> Bool, arr: Array[T]) -> Bool { + let mut i = 0; + let mut result = true; + while i < len(arr) && result { + if !pred(arr[i]) { result = false; } + i += 1; + } + result +} + +// Find first element satisfying predicate +fn find[T](pred: fn(T) -> Bool, arr: Array[T]) -> Option[T] { + let mut i = 0; + while i < len(arr) { + if pred(arr[i]) { return Some(arr[i]); } + i += 1; + } + None +} + +// Find index of first element satisfying predicate +fn find_index[T](pred: fn(T) -> Bool, arr: Array[T]) -> Option[Int] { + let mut i = 0; + while i < len(arr) { + if pred(arr[i]) { return Some(i); } + i += 1; + } + None +} + +// Reverse an array (returns new array) +fn reverse[T](arr: Array[T]) -> Array[T] { + let n = len(arr); + if n <= 1 { arr } + else { + let mut result = []; + let mut i = n - 1; + while i >= 0 { + result = push(result, arr[i]); + i -= 1; + } + result + } +} + +// Take first n elements +fn take[T](n: Int, arr: Array[T]) -> Array[T] { + let count = if n > len(arr) { len(arr) } else { n }; + let mut result = []; + let mut i = 0; + while i < count { + result = push(result, arr[i]); + i += 1; + } + result +} + +// Drop first n elements +fn drop[T](n: Int, arr: Array[T]) -> Array[T] { + let start = if n > len(arr) { len(arr) } else { n }; + let mut result = []; + let mut i = start; + while i < len(arr) { + result = push(result, arr[i]); + i += 1; + } + result +} + +// Zip two arrays together +fn zip[A, B](a: Array[A], b: Array[B]) -> Array[(A, B)] { + let n = if len(a) < len(b) { len(a) } else { len(b) }; + let mut result = []; + let mut i = 0; + while i < n { + result = push(result, (a[i], b[i])); + i += 1; + } + result +} + +// Enumerate with indices +fn enumerate[T](arr: Array[T]) -> Array[(Int, T)] { + let mut result = []; + let mut i = 0; + while i < len(arr) { + result = push(result, (i, arr[i])); + i += 1; + } + result +} + +// Flatten nested arrays +fn flatten[T](arrs: Array[Array[T]]) -> Array[T] { + let 
mut result = []; + let mut i = 0; + while i < len(arrs) { + let mut j = 0; + while j < len(arrs[i]) { + result = push(result, arrs[i][j]); + j += 1; + } + i += 1; + } + result +} + +// Flat map +fn flat_map[T, U](f: fn(T) -> Array[U], arr: Array[T]) -> Array[U] { + flatten(map(f, arr)) +} + +// Sort array (simple insertion sort for now) +fn sort(arr: Array[Int]) -> Array[Int] { + let n = len(arr); + if n <= 1 { arr } + else { + let mut result = arr; + let mut i = 1; + while i < n { + let key = result[i]; + let mut j = i - 1; + while j >= 0 && result[j] > key { + result[j + 1] = result[j]; + j -= 1; + } + result[j + 1] = key; + i += 1; + } + result + } +} +|} + +(** Option module *) +let option_source = {| +module Std.Option; + +// Check if option is Some +fn is_some[T](opt: Option[T]) -> Bool { + match opt { + Some(_) => true, + None => false, + } +} + +// Check if option is None +fn is_none[T](opt: Option[T]) -> Bool { + match opt { + None => true, + Some(_) => false, + } +} + +// Unwrap with default +fn unwrap_or[T](opt: Option[T], default: T) -> T { + match opt { + Some(x) => x, + None => default, + } +} + +// Unwrap or panic +fn unwrap[T](opt: Option[T]) -> T { + match opt { + Some(x) => x, + None => panic("called unwrap on None"), + } +} + +// Map over option +fn map[T, U](f: fn(T) -> U, opt: Option[T]) -> Option[U] { + match opt { + Some(x) => Some(f(x)), + None => None, + } +} + +// Flat map / and_then +fn and_then[T, U](f: fn(T) -> Option[U], opt: Option[T]) -> Option[U] { + match opt { + Some(x) => f(x), + None => None, + } +} + +// Or else +fn or_else[T](opt: Option[T], alternative: Option[T]) -> Option[T] { + match opt { + Some(_) => opt, + None => alternative, + } +} + +// Filter option +fn filter[T](pred: fn(T) -> Bool, opt: Option[T]) -> Option[T] { + match opt { + Some(x) => if pred(x) { Some(x) } else { None }, + None => None, + } +} + +// Convert to array +fn to_array[T](opt: Option[T]) -> Array[T] { + match opt { + Some(x) => [x], + None => [], + 
} +} + +// Zip two options +fn zip[A, B](a: Option[A], b: Option[B]) -> Option[(A, B)] { + match (a, b) { + (Some(x), Some(y)) => Some((x, y)), + _ => None, + } +} +|} + +(** Result module *) +let result_source = {| +module Std.Result; + +// Check if result is Ok +fn is_ok[T, E](res: Result[T, E]) -> Bool { + match res { + Ok(_) => true, + Err(_) => false, + } +} + +// Check if result is Err +fn is_err[T, E](res: Result[T, E]) -> Bool { + match res { + Err(_) => true, + Ok(_) => false, + } +} + +// Unwrap Ok value or panic +fn unwrap[T, E](res: Result[T, E]) -> T { + match res { + Ok(x) => x, + Err(e) => panic("called unwrap on Err"), + } +} + +// Unwrap Err value or panic +fn unwrap_err[T, E](res: Result[T, E]) -> E { + match res { + Err(e) => e, + Ok(_) => panic("called unwrap_err on Ok"), + } +} + +// Unwrap with default +fn unwrap_or[T, E](res: Result[T, E], default: T) -> T { + match res { + Ok(x) => x, + Err(_) => default, + } +} + +// Map over Ok value +fn map[T, U, E](f: fn(T) -> U, res: Result[T, E]) -> Result[U, E] { + match res { + Ok(x) => Ok(f(x)), + Err(e) => Err(e), + } +} + +// Map over Err value +fn map_err[T, E, F](f: fn(E) -> F, res: Result[T, E]) -> Result[T, F] { + match res { + Ok(x) => Ok(x), + Err(e) => Err(f(e)), + } +} + +// Flat map / and_then +fn and_then[T, U, E](f: fn(T) -> Result[U, E], res: Result[T, E]) -> Result[U, E] { + match res { + Ok(x) => f(x), + Err(e) => Err(e), + } +} + +// Or else +fn or_else[T, E, F](f: fn(E) -> Result[T, F], res: Result[T, E]) -> Result[T, F] { + match res { + Ok(x) => Ok(x), + Err(e) => f(e), + } +} + +// Convert to Option (discards error) +fn ok[T, E](res: Result[T, E]) -> Option[T] { + match res { + Ok(x) => Some(x), + Err(_) => None, + } +} + +// Convert to Option (discards ok) +fn err[T, E](res: Result[T, E]) -> Option[E] { + match res { + Ok(_) => None, + Err(e) => Some(e), + } +} +|} + +(** IO module placeholders *) +let io_source = {| +module Std.IO; + +// Note: These are effect signatures, 
actual implementation requires runtime support + +// Read a line from stdin +// effect fn read_line() -> String / IO; + +// Write to stdout +// effect fn write(s: String) -> () / IO; + +// Write line to stdout +// effect fn write_line(s: String) -> () / IO; + +// Read entire file +// effect fn read_file(path: String) -> Result[String, IOError] / IO; + +// Write to file +// effect fn write_file(path: String, content: String) -> Result[(), IOError] / IO; + +// IO error type +enum IOError { + NotFound(String), + PermissionDenied(String), + Other(String), +} +|} + +(** Concurrency module placeholders *) +let async_source = {| +module Std.Async; + +// Note: Async requires effect system and runtime support + +// Future type (placeholder) +// type Future[T] = ... + +// Spawn async task +// effect fn spawn[T](f: fn() -> T) -> Future[T] / Async; + +// Await a future +// effect fn await[T](fut: Future[T]) -> T / Async; + +// Sleep for milliseconds +// effect fn sleep(ms: Int) -> () / Async; + +// Channel types (placeholder) +// struct Sender[T] { ... } +// struct Receiver[T] { ... 
(** Load a specific standard-library module into [env].

    [name] is looked up in [modules]; the associated source is parsed with
    [Parse_driver.parse_string] and evaluated with [Eval.eval_program].

    @return [Ok ()] on success, or [Error msg] when [name] is not a known
    module or when lexing, parsing or evaluation fails. *)
let load_module env name =
  match List.assoc_opt name modules with
  | None -> Error (Printf.sprintf "Unknown module: %s" name)
  | Some source ->
    (try
       let prog = Parse_driver.parse_string ~file:name source in
       Eval.eval_program env prog;
       Ok ()
     with
     (* Lexer errors and [Failure] are handled by the command-line driver for
        this same parse-and-evaluate pipeline; convert them here too so this
        function reports them through its [result] instead of raising. *)
     | Lexer.Lexer_error (msg, _) -> Error ("Lexer error: " ^ msg)
     | Parse_driver.Parse_error (msg, _) -> Error ("Parse error: " ^ msg)
     | Eval.Runtime_error (msg, _) -> Error ("Runtime error: " ^ msg)
     | Failure msg -> Error ("Error: " ^ msg))
(** Counter holding the identifier that the next scope will receive. *)
let scope_counter = ref 0

(** Return the current value of [scope_counter] and advance the counter by
    one, so that successive calls yield 0, 1, 2, ... *)
let fresh_scope_id () =
  let allocated = !scope_counter in
  scope_counter := allocated + 1;
  allocated
(** Check if we're inside a loop (for break/continue).

    Walks from [scope] towards the root and returns [true] when a [ScopeLoop]
    is reached.  The walk stops with [false] at the first [ScopeFunction]: a
    loop that encloses a function or lambda does not make the function body
    part of that loop, so break/continue must not be accepted there. *)
let rec in_loop scope =
  match scope.scope_kind with
  | ScopeLoop -> true
  | ScopeFunction -> false
  | ScopeGlobal | ScopeModule | ScopeBlock | ScopeMatch
  | ScopeImpl | ScopeTrait ->
    (match scope.scope_parent with
     | Some parent -> in_loop parent
     | None -> false)
(** Leave the current scope, making its parent the current scope.

    @raise Failure if the current scope has no parent. *)
let exit_scope st =
  let leaving = st.st_current_scope in
  match leaving.scope_parent with
  | None -> failwith "Cannot exit global scope"
  | Some enclosing -> st.st_current_scope <- enclosing
(** Render a type error as a one-line, human-readable message.

    Source spans carried by the error are not included in the text. *)
let error_to_string = function
  | TypeMismatch (expected, got, _) ->
    Printf.sprintf "Type mismatch: expected %s, got %s"
      (show_ty expected) (show_ty got)
  | UnboundVariable (name, _) ->
    Printf.sprintf "Unbound variable: %s" name
  | UnboundType (name, _) ->
    Printf.sprintf "Unbound type: %s" name
  | NotAFunction (ty, _) ->
    Printf.sprintf "Expected function, got %s" (show_ty ty)
  | WrongArity (expected, got, _) ->
    Printf.sprintf "Expected %d arguments, got %d" expected got
  | OccursCheck (id, ty) ->
    Printf.sprintf "Occurs check failed: ?%d in %s" id (show_ty ty)
  | CannotUnify (t1, t2) ->
    Printf.sprintf "Cannot unify %s with %s" (show_ty t1) (show_ty t2)
  | InfiniteType (id, ty) ->
    Printf.sprintf "Infinite type: ?%d = %s" id (show_ty ty)
  | RowMismatch (label, _) ->
    Printf.sprintf "Row mismatch at label: %s" label
  | MissingField (field, _) ->
    Printf.sprintf "Missing field: %s" field
  | DuplicateField (field, _) ->
    Printf.sprintf "Duplicate field: %s" field
  | NotMutable (name, _) ->
    Printf.sprintf "Cannot mutate immutable binding: %s" name
  | LinearityViolation (name, _) ->
    Printf.sprintf "Linear variable used multiple times: %s" name
  | EffectNotAllowed (eff, _) ->
    (* [%a] under [Printf.sprintf] expects a [unit -> 'a -> string] printer;
       [Format.asprintf] is the variant that accepts a formatter-based
       printer.
       NOTE(review): assumes [pp_eff] has the usual
       [Format.formatter -> _ -> unit] shape - confirm against its
       definition. *)
    Format.asprintf "Effect not allowed: %a" pp_eff eff
(** Instantiate a type scheme with fresh type variables.

    One fresh [TVar] is created per entry of [scheme.sc_tyvars], in list
    order.  Every [TRigid] in [scheme.sc_type] whose name has such an entry is
    replaced by the corresponding variable; rigid names without an entry, row
    variables and effect annotations are returned unchanged.  The [st]
    argument is not consulted. *)
let instantiate st scheme =
  let fresh_for = Hashtbl.create 8 in
  List.iter
    (fun (name, _) -> Hashtbl.replace fresh_for name (TVar (fresh_tyvar ())))
    scheme.sc_tyvars;
  let rec inst = function
    | TRigid name as rigid ->
      Option.value (Hashtbl.find_opt fresh_for name) ~default:rigid
    | TApp (head, args) -> TApp (head, List.map inst args)
    | TArrow (dom, cod, eff) -> TArrow (inst dom, inst cod, eff)
    | TForall (v, k, body) -> TForall (v, k, inst body)
    | TTuple components -> TTuple (List.map inst components)
    | TRecord row -> TRecord (inst_row row)
    | TRef inner -> TRef (inst inner)
    | TMut inner -> TMut (inst inner)
    | TOwn inner -> TOwn (inst inner)
    | TRefined (inner, refinement) -> TRefined (inst inner, refinement)
    | TQuantified (q, inner) -> TQuantified (q, inst inner)
    | other -> other
  and inst_row = function
    | REmpty -> REmpty
    | RVar _ as var -> var
    | RExtend (label, field_ty, rest) ->
      RExtend (label, inst field_ty, inst_row rest)
  in
  inst scheme.sc_type
+(** Generalize a type to a scheme.
+
+    Type variables free in [ty] but not free in any binding of [env]'s own
+    scope (parent scopes are not consulted) are counted and given names
+    [t0], [t1], ...
+
+    NOTE(review): [sc_type] still holds the original [TVar]s, while
+    [instantiate] only rewrites [TRigid] names, so the resulting scheme
+    behaves monomorphically.  Fixing that needs the current substitution
+    applied to [ty] first; confirm with the caller before changing it. *)
+let generalize env ty =
+  let env_vars = Hashtbl.fold (fun _ scheme acc ->
+    let ty = scheme.sc_type in
+    free_tyvars ty @ acc
+  ) env.env_vars [] in
+  let ty_vars = free_tyvars ty in
+  let free = List.filter (fun v -> not (List.mem v env_vars)) ty_vars in
+  if free = [] then
+    mono ty
+  else
+    let tyvars = List.mapi (fun i _id ->
+      let name = Printf.sprintf "t%d" i in
+      (name, KType)
+    ) (List.sort_uniq compare free) in
+    { sc_tyvars = tyvars; sc_rowvars = []; sc_effvars = []; sc_type = ty }
+
+(** Unification.  Both sides are resolved through [st.subst] first; solved
+    variables are recorded in [st.subst].
+    @raise Type_error when the two types cannot be made equal. *)
+let rec unify st t1 t2 =
+  let t1 = apply_subst st.subst t1 in
+  let t2 = apply_subst st.subst t2 in
+  match t1, t2 with
+  | TUnit, TUnit -> ()
+  | TBool, TBool -> ()
+  | TInt, TInt -> ()
+  | TNat, TNat -> ()
+  | TFloat, TFloat -> ()
+  | TChar, TChar -> ()
+  | TString, TString -> ()
+  | TNever, _ -> ()  (* Never unifies with anything *)
+  | _, TNever -> ()
+  | TVar id1, TVar id2 when id1 = id2 -> ()
+  | TVar id, t | t, TVar id ->
+    if occurs id t then
+      raise (Type_error (InfiniteType (id, t)))
+    else
+      extend_ty_subst st.subst id t
+  | TRigid n1, TRigid n2 when n1 = n2 -> ()
+  | TApp (n1, args1), TApp (n2, args2) when n1 = n2 ->
+    if List.length args1 <> List.length args2 then
+      raise (Type_error (CannotUnify (t1, t2)));
+    List.iter2 (unify st) args1 args2
+  | TArrow (a1, r1, e1), TArrow (a2, r2, e2) ->
+    unify st a1 a2;
+    unify st r1 r2;
+    unify_effect st e1 e2
+  | TTuple ts1, TTuple ts2 when List.length ts1 = List.length ts2 ->
+    List.iter2 (unify st) ts1 ts2
+  | TRecord r1, TRecord r2 ->
+    unify_row st r1 r2
+  | TRef t1, TRef t2 -> unify st t1 t2
+  | TMut t1, TMut t2 -> unify st t1 t2
+  | TOwn t1, TOwn t2 -> unify st t1 t2
+  | TQuantified (q1, t1), TQuantified (q2, t2) when q1 = q2 ->
+    unify st t1 t2
+  | _ -> raise (Type_error (CannotUnify (t1, t2)))
+
+(** Unify two rows, rewriting the right-hand row when labels appear in a
+    different order. *)
+and unify_row st r1 r2 =
+  let r1 = apply_row_subst st.subst r1 in
+  let r2 = apply_row_subst st.subst r2 in
+  match r1, r2 with
+  | REmpty, REmpty -> ()
+  | RVar id1, RVar id2 when id1 = id2 -> ()
+  | RVar id, r | r, RVar id ->
+    extend_row_subst st.subst id r
+  | RExtend (l1, t1, r1'), RExtend (l2, t2, r2') when l1 = l2 ->
+    unify st t1 t2;
+    unify_row st r1' r2'
+  | RExtend (l1, t1, r1'), r2 ->
+    (* Row rewriting: find l1 in r2 *)
+    let rec find_and_remove label = function
+      | REmpty -> None
+      | RVar _ as r -> Some (TVar (fresh_tyvar ()), r)  (* Extend with fresh *)
+      | RExtend (l, t, rest) when l = label ->
+        Some (t, rest)
+      | RExtend (l, t, rest) ->
+        match find_and_remove label rest with
+        | Some (found_t, rest') -> Some (found_t, RExtend (l, t, rest'))
+        | None -> None
+    in
+    (match find_and_remove l1 r2 with
+     | Some (t2, r2') ->
+       unify st t1 t2;
+       unify_row st r1' r2'
+     | None ->
+       raise (Type_error (RowMismatch (l1, None))))
+  | _ -> raise (Type_error (CannotUnify (TRecord r1, TRecord r2)))
+
+(** Unify two effects.  Deliberately permissive: shapes that do not line up
+    are accepted silently. *)
+and unify_effect st e1 e2 =
+  let e1 = apply_eff_subst st.subst e1 in
+  let e2 = apply_eff_subst st.subst e2 in
+  match e1, e2 with
+  | EEmpty, EEmpty -> ()
+  | EVar id1, EVar id2 when id1 = id2 -> ()
+  | EVar id, e | e, EVar id ->
+    extend_eff_subst st.subst id e
+  | ECon (n1, args1), ECon (n2, args2)
+    when n1 = n2 && List.compare_lengths args1 args2 = 0 ->
+    (* The length guard keeps [List.iter2] from raising [Invalid_argument];
+       an arity mismatch falls through to the permissive default below. *)
+    List.iter2 (unify st) args1 args2
+  | EUnion (e1a, e1b), e2 ->
+    (* Effect subsumption - both parts must be in e2 *)
+    unify_effect st e1a e2;
+    unify_effect st e1b e2
+  | e1, EUnion (e2a, e2b) ->
+    unify_effect st e1 e2a;
+    unify_effect st e1 e2b
+  | _ -> ()  (* Effects are more permissive *)
+
+(** Convert AST type expression to internal type.  Built-in names map to the
+    primitive types; other names are resolved through [env] and fall back to
+    a rigid variable of that name. *)
+let rec ast_to_type env (ty_expr: type_expr) : ty =
+  match ty_expr with
+  | TyVar id | TyCon id ->
+    (match id.name with
+     | "Int" -> TInt
+     | "Bool" -> TBool
+     | "Float" -> TFloat
+     | "Char" -> TChar
+     | "String" -> TString
+     | "Unit" -> TUnit
+     | "Never" -> TNever
+     | "Nat" -> TNat
+     | name ->
+       match lookup_type env name with
+       | Some ty -> ty
+       | None -> TRigid name)  (* Assume it's a type variable *)
+  | TyApp (id, args) ->
+    let arg_types = List.map (function
+      | TyArg t -> ast_to_type env t
+      | NatArg _ -> TInt  (* Simplification: treat nat args as int *)
+    ) args in
+    TApp (id.name, arg_types)
+  | TyArrow (t1, t2, eff_opt) ->
+    let eff = match eff_opt with
+      | None -> EEmpty
+      | Some e -> ast_to_effect env e
+    in
+    TArrow (ast_to_type env t1, ast_to_type env t2, eff)
+  | TyDepArrow { da_param_ty; da_ret_ty; da_eff; _ } ->
+    (* Simplify dependent arrow to regular arrow for now *)
+    let eff = match da_eff with
+      | None -> EEmpty
+      | Some e -> ast_to_effect env e
+    in
+    TArrow (ast_to_type env da_param_ty, ast_to_type env da_ret_ty, eff)
+  | TyTuple tys ->
+    TTuple (List.map (ast_to_type env) tys)
+  | TyRecord (fields, rest) ->
+    let row = List.fold_right (fun rf acc ->
+      RExtend (rf.rf_name.name, ast_to_type env rf.rf_ty, acc)
+    ) fields (match rest with
+      | None -> REmpty
+      | Some _ -> RVar (fresh_rowvar ()))
+    in
+    TRecord row
+  | TyOwn t -> TOwn (ast_to_type env t)
+  | TyRef t -> TRef (ast_to_type env t)
+  | TyMut t -> TMut (ast_to_type env t)
+  | TyRefined (t, _) -> ast_to_type env t  (* Ignore refinement for now *)
+  | TyHole -> TVar (fresh_tyvar ())
+
+(** Convert an AST effect expression; effect arguments are dropped. *)
+and ast_to_effect _env (eff_expr: effect_expr) : effect =
+  match eff_expr with
+  | EffVar id -> ECon (id.name, [])
+  | EffCon (id, _) -> ECon (id.name, [])
+  | EffUnion (e1, e2) ->
+    EUnion (ast_to_effect _env e1, ast_to_effect _env e2)
+
+(** Synthesize type of expression.  Problems are recorded in [st] where the
+    code can recover; a fresh type variable stands in for unknown results. *)
+let rec synth st env (expr: expr) : ty =
+  match expr with
+  | ExprSpan (e, _) -> synth st env e
+
+  | ExprLit lit -> synth_literal lit
+
+  | ExprVar id ->
+    (match lookup_var env id.name with
+     | Some scheme -> instantiate st scheme
+     | None ->
+       add_error st (UnboundVariable (id.name, id.span));
+       TVar (fresh_tyvar ()))
+
+  | ExprLet { el_mut = _; el_pat; el_ty; el_value; el_body } ->
+    let value_ty = match el_ty with
+      | Some ty_expr ->
+        let expected = ast_to_type env ty_expr in
+        check st env el_value expected;
+        expected
+      | None -> synth st env el_value
+    in
+    let env' = child_env env in
+    bind_pattern env' el_pat value_ty;
+    (match el_body with
+     | Some body -> synth st env' body
+     | None -> TUnit)
+
+  | ExprIf { ei_cond; ei_then; ei_else } ->
+    check st env ei_cond TBool;
+    let then_ty = synth st env ei_then in
+    (match ei_else with
+     | Some else_e ->
+       check st env else_e then_ty;
+       then_ty
+     | None -> TUnit)
+
+  | ExprMatch { em_scrutinee; em_arms } ->
+    let scrutinee_ty = synth st env em_scrutinee in
+    (match em_arms with
+     | [] -> TVar (fresh_tyvar ())
+     | arm :: rest ->
+       (* The first arm fixes the result type; later arms are checked
+          against it, each in its own scope. *)
+       let env' = child_env env in
+       check_pattern st env' arm.ma_pat scrutinee_ty;
+       Option.iter (fun guard -> check st env' guard TBool) arm.ma_guard;
+       let result_ty = synth st env' arm.ma_body in
+       List.iter (fun arm ->
+         let arm_env = child_env env in
+         check_pattern st arm_env arm.ma_pat scrutinee_ty;
+         Option.iter (fun guard -> check st arm_env guard TBool) arm.ma_guard;
+         check st arm_env arm.ma_body result_ty
+       ) rest;
+       result_ty)
+
+  | ExprLambda { elam_params; elam_ret_ty; elam_body } ->
+    let env' = child_env env in
+    let param_types = List.map (fun p ->
+      let ty = ast_to_type env p.p_ty in
+      bind_var env' p.p_name.name (mono ty);
+      ty
+    ) elam_params in
+    let body_ty = match elam_ret_ty with
+      | Some ret_ty ->
+        let expected = ast_to_type env ret_ty in
+        check st env' elam_body expected;
+        expected
+      | None -> synth st env' elam_body
+    in
+    (* Curried arrow: one [TArrow] per parameter. *)
+    List.fold_right (fun param_ty acc ->
+      TArrow (param_ty, acc, env.env_effects)
+    ) param_types body_ty
+
+  | ExprApp (func, args) ->
+    let func_ty = synth st env func in
+    synth_app st env func_ty args
+
+  | ExprField (e, field) ->
+    let record_ty = synth st env e in
+    synth_field st record_ty field.name
+
+  | ExprTupleIndex (e, idx) ->
+    let tuple_ty = synth st env e in
+    synth_tuple_index st tuple_ty idx
+
+  | ExprIndex (arr, idx) ->
+    let arr_ty = synth st env arr in
+    check st env idx TInt;
+    synth_index st arr_ty
+
+  | ExprTuple exprs ->
+    TTuple (List.map (synth st env) exprs)
+
+  | ExprArray exprs ->
+    (match exprs with
+     | [] -> TApp ("Array", [TVar (fresh_tyvar ())])
+     | e :: rest ->
+       let elem_ty = synth st env e in
+       List.iter (fun e -> check st env e elem_ty) rest;
+       TApp ("Array", [elem_ty]))
+
+  | ExprRecord { er_fields; er_spread } ->
+    let base_row = match er_spread with
+      | Some spread ->
+        let spread_ty = synth st env spread in
+        (match apply_subst st.subst spread_ty with
+         | TRecord row -> row
+         | _ -> REmpty)
+      | None -> REmpty
+    in
+    let row = List.fold_right (fun (field_id, expr_opt) acc ->
+      let ty = match expr_opt with
+        | Some e -> synth st env e
+        | None ->
+          (* Shorthand: {x} means {x: x} *)
+          match lookup_var env field_id.name with
+          | Some scheme -> instantiate st scheme
+          | None ->
+            add_error st (UnboundVariable (field_id.name, field_id.span));
+            TVar (fresh_tyvar ())
+      in
+      RExtend (field_id.name, ty, acc)
+    ) er_fields base_row in
+    TRecord row
+
+  | ExprRowRestrict (e, field) ->
+    let record_ty = synth st env e in
+    synth_row_restrict st record_ty field.name
+
+  | ExprBinary (e1, op, e2) ->
+    synth_binary st env op e1 e2
+
+  | ExprUnary (op, e) ->
+    synth_unary st env op e
+
+  | ExprBlock blk ->
+    synth_block st env blk
+
+  | ExprReturn e_opt ->
+    (match e_opt with
+     | Some e -> ignore (synth st env e)
+     | None -> ());
+    TNever  (* Return diverges *)
+
+  | ExprVariant (type_id, _variant_id) ->
+    (* For now, return a type variable that will be unified *)
+    TApp (type_id.name, [])
+
+  | ExprTry _ -> TVar (fresh_tyvar ())  (* TODO *)
+  | ExprHandle _ -> TVar (fresh_tyvar ())  (* TODO *)
+  | ExprResume _ -> TVar (fresh_tyvar ())  (* TODO *)
+  | ExprUnsafe _ -> TVar (fresh_tyvar ())  (* TODO *)
+
+and synth_literal = function
+  | LitInt _ -> TInt
+  | LitFloat _ -> TFloat
+  | LitBool _ -> TBool
+  | LitChar _ -> TChar
+  | LitString _ -> TString
+  | LitUnit _ -> TUnit
+
+(** Apply [func_ty] to [args] one at a time (curried). *)
+and synth_app st env func_ty args =
+  let func_ty = apply_subst st.subst func_ty in
+  match func_ty, args with
+  | _, [] -> func_ty
+  | TArrow (param_ty, ret_ty, _), arg :: rest ->
+    check st env arg param_ty;
+    synth_app st env ret_ty rest
+  | TVar id, arg :: rest ->
+    (* Unknown callee: solve it to an arrow from the argument's type. *)
+    let arg_ty = synth st env arg in
+    let ret_ty = TVar (fresh_tyvar ()) in
+    extend_ty_subst st.subst id (TArrow (arg_ty, ret_ty, EEmpty));
+    synth_app st env ret_ty rest
+  | _, _ ->
+    add_error st (NotAFunction (func_ty, None));
+    TVar (fresh_tyvar ())
+
+and synth_field st record_ty field_name =
+  let record_ty = apply_subst st.subst record_ty in
+  match record_ty with
+  | TRecord row ->
+    let rec find_field = function
+      | REmpty ->
+        add_error st (MissingField (field_name, None));
+        TVar (fresh_tyvar ())
+      | RVar _ ->
+        TVar (fresh_tyvar ())  (* Unknown row - return fresh *)
+      | RExtend (l, t, rest) ->
+        if l = field_name then t
+        else find_field rest
+    in
+    find_field row
+  | TVar id ->
+    let field_ty = TVar (fresh_tyvar ()) in
+    let rest_row = RVar (fresh_rowvar ()) in
+    let record_row = RExtend (field_name, field_ty, rest_row) in
+    extend_ty_subst st.subst id (TRecord record_row);
+    field_ty
+  | _ ->
+    add_error st (MissingField (field_name, None));
+    TVar (fresh_tyvar ())
+
+and synth_tuple_index st tuple_ty idx =
+  let tuple_ty = apply_subst st.subst tuple_ty in
+  match tuple_ty with
+  | TTuple ts when idx >= 0 && idx < List.length ts ->
+    List.nth ts idx
+  | _ -> TVar (fresh_tyvar ())
+
+and synth_index st arr_ty =
+  let arr_ty = apply_subst st.subst arr_ty in
+  match arr_ty with
+  | TApp ("Array", [elem_ty]) -> elem_ty
+  | TString -> TChar
+  | _ -> TVar (fresh_tyvar ())
+
+and synth_row_restrict st record_ty field_name =
+  let record_ty = apply_subst st.subst record_ty in
+  match record_ty with
+  | TRecord row ->
+    let rec remove_field = function
+      | REmpty -> REmpty
+      | RVar id -> RVar id
+      | RExtend (l, t, rest) ->
+        if l = field_name then rest
+        else RExtend (l, t, remove_field rest)
+    in
+    TRecord (remove_field row)
+  | _ -> TVar (fresh_tyvar ())
+
+and synth_binary st env op e1 e2 =
+  match op with
+  | OpAdd | OpSub | OpMul | OpDiv | OpMod ->
+    (* Try int first, then float.
+       NOTE(review): when every attempt fails the result silently defaults
+       to [TInt] and no error is recorded. *)
+    let t1 = synth st env e1 in
+    let t2 = synth st env e2 in
+    (try unify st t1 TInt; unify st t2 TInt; TInt
+     with Type_error _ ->
+       try unify st t1 TFloat; unify st t2 TFloat; TFloat
+       with Type_error _ ->
+         (* For + also allow string concat *)
+         if op = OpAdd then begin
+           try unify st t1 TString; unify st t2 TString; TString
+           with Type_error _ -> TInt
+         end else TInt)
+  | OpEq | OpNe ->
+    let t1 = synth st env e1 in
+    check st env e2 t1;
+    TBool
+  | OpLt | OpLe | OpGt | OpGe ->
+    let t1 = synth st env e1 in
+    check st env e2 t1;
+    TBool
+  | OpAnd | OpOr ->
+    check st env e1 TBool;
+    check st env e2 TBool;
+    TBool
+  | OpBitAnd | OpBitOr | OpBitXor | OpShl | OpShr ->
+    check st env e1 TInt;
+    check st env e2 TInt;
+    TInt
+
+and synth_unary st env op e =
+  match op with
+  | OpNeg ->
+    let t = synth st env e in
+    (try unify st t TInt; TInt
+     with Type_error _ ->
+       try unify st t TFloat; TFloat
+       with Type_error _ ->
+         (* Neither Int nor Float: record the mismatch rather than letting
+            [Type_error] escape, because [check] runs [synth] outside its
+            own handler. *)
+         add_error st (TypeMismatch (TInt, apply_subst st.subst t, None));
+         TInt)
+  | OpNot -> check st env e TBool; TBool
+  | OpBitNot -> check st env e TInt; TInt
+  | OpRef -> TRef (synth st env e)
+  | OpDeref ->
+    let t = synth st env e in
+    let elem_ty = TVar (fresh_tyvar ()) in
+    (try unify st t (TRef elem_ty); elem_ty
+     with Type_error _ ->
+       try unify st t (TMut elem_ty); elem_ty
+       with Type_error _ -> elem_ty)
+
+and synth_block st env { blk_stmts; blk_expr } =
+  let env' = child_env env in
+  List.iter (check_stmt st env') blk_stmts;
+  match blk_expr with
+  | Some e -> synth st env' e
+  | None -> TUnit
+
+and check_stmt st env = function
+  | StmtLet { sl_mut = _; sl_pat; sl_ty; sl_value } ->
+    let value_ty = match sl_ty with
+      | Some ty_expr ->
+        let expected = ast_to_type env ty_expr in
+        check st env sl_value expected;
+        expected
+      | None -> synth st env sl_value
+    in
+    bind_pattern env sl_pat value_ty
+  | StmtExpr e ->
+    ignore (synth st env e)
+  | StmtAssign (target, _, value) ->
+    let target_ty = synth st env target in
+    check st env value target_ty
+  | StmtWhile (cond, body) ->
+    check st env cond TBool;
+    ignore (synth_block st env body)
+  | StmtFor (pat, iter, body) ->
+    let iter_ty = synth st env iter in
+    let elem_ty = match apply_subst st.subst iter_ty with
+      | TApp ("Array", [t]) -> t
+      | TString -> TChar
+      | _ -> TVar (fresh_tyvar ())
+    in
+    let env' = child_env env in
+    bind_pattern env' pat elem_ty;
+    ignore (synth_block st env' body)
+
+(** Check expression against expected type.  A unification failure is
+    recorded as a [TypeMismatch]; it is not raised. *)
+and check st env expr expected =
+  let actual = synth st env expr in
+  try unify st actual expected
+  with Type_error _ ->
+    add_error st (TypeMismatch (expected, actual, None))
+
+(** Bind pattern variables with types.  No checking is done; shapes that do
+    not match [ty] bind nothing. *)
+and bind_pattern env pat ty =
+  match pat with
+  | PatWildcard _ -> ()
+  | PatVar id -> bind_var env id.name (mono ty)
+  | PatLit _ -> ()
+  | PatCon (_, pats) ->
+    (* For now, assume tuple-like destructuring *)
+    (match ty with
+     | TTuple ts when List.length ts = List.length pats ->
+       List.iter2 (bind_pattern env) pats ts
+     | _ -> ())
+  | PatTuple pats ->
+    (match ty with
+     | TTuple ts when List.length ts = List.length pats ->
+       List.iter2 (bind_pattern env) pats ts
+     | _ -> ())
+  | PatRecord (fields, _) ->
+    (match ty with
+     | TRecord row ->
+       List.iter (fun (field_id, pat_opt) ->
+         let field_ty = find_row_field row field_id.name in
+         match pat_opt with
+         | Some p -> bind_pattern env p field_ty
+         | None -> bind_var env field_id.name (mono field_ty)
+       ) fields
+     | _ -> ())
+  | PatOr (p1, _) -> bind_pattern env p1 ty
+  | PatAs (id, p) ->
+    bind_var env id.name (mono ty);
+    bind_pattern env p ty
+
+(** Type of [field_name] in [row], or a fresh variable when it is absent. *)
+and find_row_field row field_name =
+  match row with
+  | REmpty -> TVar (fresh_tyvar ())
+  | RVar _ -> TVar (fresh_tyvar ())
+  | RExtend (l, t, rest) ->
+    if l = field_name then t
+    else find_row_field rest field_name
+
+(** Check a pattern against the scrutinee type [expected], binding its
+    variables in [env]. *)
+and check_pattern st env pat expected =
+  match pat with
+  | PatWildcard _ -> ()
+  | PatVar id -> bind_var env id.name (mono expected)
+  | PatLit lit ->
+    let lit_ty = synth_literal lit in
+    (* Report a literal of the wrong type the same way [check] does, rather
+       than swallowing every exception. *)
+    (try unify st lit_ty expected
+     with Type_error _ ->
+       add_error st (TypeMismatch (expected, lit_ty, None)))
+  | PatCon (_, pats) ->
+    (* Simplified: assume it's a variant with tuple fields *)
+    (match apply_subst st.subst expected with
+     | TTuple ts when List.length ts = List.length pats ->
+       List.iter2 (check_pattern st env) pats ts
+     | _ -> List.iter (fun p -> check_pattern st env p (TVar (fresh_tyvar ()))) pats)
+  | PatTuple pats ->
+    (match apply_subst st.subst expected with
+     | TTuple ts when List.length ts = List.length pats ->
+       List.iter2 (check_pattern st env) pats ts
+     | _ -> ())
+  | PatRecord (fields, _) ->
+    List.iter (fun (field_id, pat_opt) ->
+      let field_ty = match apply_subst st.subst expected with
+        | TRecord row -> find_row_field row field_id.name
+        | _ -> TVar (fresh_tyvar ())
+      in
+      match pat_opt with
+      | Some p -> check_pattern st env p field_ty
+      | None -> bind_var env field_id.name (mono field_ty)
+    ) fields
+  | PatOr (p1, p2) ->
+    check_pattern st env p1 expected;
+    check_pattern st env p2 expected
+  | PatAs (id, p) ->
+    bind_var env id.name (mono expected);
+    check_pattern st env p expected
+
+(** Type check a function declaration and return its curried arrow type. *)
+let check_fn_decl st env (decl: fn_decl) =
+  let env' = child_env env in
+  (* Add type parameters *)
+  List.iter (fun tp ->
+    bind_type env' tp.tp_name.name (TRigid tp.tp_name.name)
+  ) decl.fd_type_params;
+  (* Add parameters *)
+  let param_types = List.map (fun p ->
+    let ty = ast_to_type env' p.p_ty in
+    bind_var env' p.p_name.name (mono ty);
+    ty
+  ) decl.fd_params in
+  (* Check return type *)
+  let ret_ty = match decl.fd_ret_ty with
+    | Some ty_expr -> ast_to_type env' ty_expr
+    | None -> TVar (fresh_tyvar ())
+  in
+  (* Check body *)
+  let env_with_effects = match decl.fd_eff with
+    | Some eff -> with_effects env' (ast_to_effect env' eff)
+    | None -> env'
+  in
+  (match decl.fd_body with
+   | FnBlock blk -> check st env_with_effects (ExprBlock blk) ret_ty
+   | FnExpr e -> check st env_with_effects e ret_ty);
+  (* Return function type *)
+  List.fold_right (fun param_ty acc ->
+    TArrow (param_ty, acc, env_with_effects.env_effects)
+  ) param_types ret_ty
+
+(** Type
check a program *)
+let check_program (prog: program) =
+  let st = create_state () in
+  let env = empty_env () in
+  (* Add built-in types *)
+  bind_type env "Int" TInt;
+  bind_type env "Bool" TBool;
+  bind_type env "Float" TFloat;
+  bind_type env "Char" TChar;
+  bind_type env "String" TString;
+  bind_type env "Unit" TUnit;
+  bind_type env "Never" TNever;
+  bind_type env "Nat" TNat;
+  (* Add built-in functions.  Polymorphic built-ins are schemes quantified
+     over rigid names, so [instantiate] hands every use site its own fresh
+     variables.  Binding them monomorphically over unification variables
+     would fix their argument type at the first call, and literal ids such
+     as [TVar 0] would alias variables handed out by [fresh_tyvar]. *)
+  let poly names ty = {
+    sc_tyvars = List.map (fun n -> (n, KType)) names;
+    sc_rowvars = [];
+    sc_effvars = [];
+    sc_type = ty;
+  } in
+  let a = TRigid "a" and b = TRigid "b" in
+  let arr t = TApp ("Array", [t]) in
+  let fn arg ret = TArrow (arg, ret, EEmpty) in
+  bind_var env "print" (poly ["a"] (fn a TUnit));
+  bind_var env "println" (poly ["a"] (fn a TUnit));
+  bind_var env "str" (poly ["a"] (fn a TString));
+  bind_var env "len" (poly ["a"] (fn a TInt));
+  bind_var env "type_of" (poly ["a"] (fn a TString));
+  bind_var env "range" (mono (fn TInt (arr TInt)));
+  bind_var env "map" (poly ["a"; "b"] (fn (fn a b) (fn (arr a) (arr b))));
+  bind_var env "filter" (poly ["a"] (fn (fn a TBool) (fn (arr a) (arr a))));
+  bind_var env "fold"
+    (poly ["a"; "b"] (fn (fn a (fn b a)) (fn a (fn (arr b) a))));
+  (* First pass: collect type declarations *)
+  List.iter (fun decl ->
+    match decl with
+    | TopType td ->
+      bind_type env td.td_name.name (TApp (td.td_name.name, []))
+    | _ -> ()
+  ) prog.prog_decls;
+  (* Second pass: collect function declarations, so bodies can refer to
+     functions declared later in the file *)
+  List.iter (fun decl ->
+    match decl with
+    | TopFn fd ->
+      let fn_ty = TVar (fresh_tyvar ()) in
+      bind_var env fd.fd_name.name (mono fn_ty)
+    | TopConst { tc_name; tc_ty; _ } ->
+      bind_var env tc_name.name (mono (ast_to_type env tc_ty))
+    | _ -> ()
+  ) prog.prog_decls;
+  (* Third pass: type check all declarations *)
+  List.iter (fun decl ->
+    match decl with
+    | TopFn fd ->
+      let fn_ty = check_fn_decl st env fd in
+      (* Update the binding with the inferred type *)
+      bind_var env fd.fd_name.name (generalize env fn_ty)
+    | TopConst { tc_ty; tc_value; _ } ->
+      let expected = ast_to_type env tc_ty in
+      check st env tc_value expected
+    | _ -> ()
+  ) prog.prog_decls;
+  (* Errors are accumulated newest-first; return them in source order. *)
+  List.rev st.errors
diff --git a/lib/types.ml b/lib/types.ml
new file mode 100644
index 0000000..825758d
--- /dev/null
+++ b/lib/types.ml
@@ -0,0 +1,287 @@
+(** Internal type representation for type checking *)
+
+(** Type variable ID *)
+type tyvar_id = int
+[@@deriving show, eq, ord]
+
+let next_tyvar = ref 0
+let fresh_tyvar () =
+  let id = !next_tyvar in
+  incr next_tyvar;
+  id
+
+(** Row variable ID *)
+type rowvar_id = int
+[@@deriving show, eq, ord]
+
+let next_rowvar = ref 0
+let fresh_rowvar () =
+  let id = !next_rowvar in
+  incr next_rowvar;
+  id
+
+(** Effect variable ID *)
+type effvar_id = int
+[@@deriving show, eq, ord]
+
+let next_effvar = ref 0
+let fresh_effvar () =
+  let id = !next_effvar in
+  incr next_effvar;
+  id
+
+(** Quantity (QTT) *)
+type quantity =
+  | QZero   (** Erased - compile time only *)
+  | QOne    (** Linear - exactly once *)
+  | QOmega  (** Unrestricted *)
+[@@deriving show, eq]
+
+(** Multiply quantities (semiring) *)
+let mult_quantity q1 q2 =
+  match q1, q2 with
+  | QZero, _ | _, QZero -> QZero
+  | QOne, QOne -> QOne
+  | QOmega, _ | _, QOmega -> QOmega
+
+(** Add quantities (join in semiring) *)
+let add_quantity q1 q2 =
+  match q1, q2 with
+  | QZero, q | q, QZero -> q
+  | QOne, QOne -> QOmega  (* Used twice = unrestricted *)
+  | QOmega, _ | _, QOmega -> QOmega
+
+(** Internal type representation *)
+type ty =
+  | TUnit
+  | TBool
+  | TInt
+  | TNat
+  | TFloat
+  | TChar
+  | TString
+  | TNever                          (** Bottom type *)
+  | TVar of tyvar_id                (** Unification variable *)
+  | TRigid of string                (** Rigid type variable (skolem) *)
+  | TApp of string * ty list        (** Type constructor application *)
+  | TArrow of ty * ty * effect      (** Function type: T -> U / E *)
+  | TForall of string * kind * ty   (** Polymorphic type *)
+  | TTuple of ty list               (** Tuple type *)
+  | TRecord of row                  (** Record type with row *)
+  | TRef of ty                      (** Reference type (immutable borrow) *)
+  | TMut of ty                      (** Mutable reference *)
+  | TOwn of ty                      (** Owned type *)
+  | TRefined of ty * refinement     (** Refinement type *)
+  | TQuantified of quantity * ty    (** Quantity-annotated type *)
+
+(** Row type for records *)
+and row =
+  | REmpty                          (** {} *)
+  | RVar of rowvar_id               (** Row variable *)
+  | RExtend of string * ty * row    (** {l: T | r} *)
+
+(** Effect type *)
+and effect =
+  | EEmpty                          (** Pure *)
+  | EVar of effvar_id               (** Effect variable *)
+  | ECon of string * ty list        (** Effect constructor *)
+  | EUnion of effect * effect       (** Effect union *)
+
+(** Refinement predicate (simplified) *)
+and refinement =
+  | RTrue
+  | RFalse
+  | REq of string * int             (** x = n *)
+  | RLt of string * int             (** x < n *)
+  | RGt of string * int             (** x > n *)
+  | RAnd of refinement * refinement
+  | ROr of refinement * refinement
+  | RNot of refinement
+
+(** Kind *)
+and kind =
+  | KType                           (** Type *)
+  | KNat                            (** Natural number *)
+  | KRow                            (** Row *)
+  | KEffect                         (** Effect *)
+  | KArrow of kind * kind           (** κ → κ *)
+[@@deriving show, eq]
+
+(** Type scheme - polymorphic type *)
+type scheme = {
+  sc_tyvars: (string * kind) list;  (** Bound type variables *)
+  sc_rowvars: string list;          (** Bound row variables *)
+  sc_effvars: string list;          (** Bound effect variables *)
+  sc_type: ty;                      (** The type *)
+}
+[@@deriving show]
+
+(** Create a monomorphic scheme *)
+let mono ty = {
+  sc_tyvars = [];
+  sc_rowvars = [];
+  sc_effvars = [];
+  sc_type = ty;
+}
+
+(** Substitution for type variables *)
+type subst = {
+  ty_subst: (tyvar_id, ty) Hashtbl.t;
+  row_subst: (rowvar_id, row) Hashtbl.t;
+  eff_subst: (effvar_id, effect) Hashtbl.t;
+}
+
+let empty_subst () = {
+  ty_subst = Hashtbl.create 16;
+  row_subst = Hashtbl.create 8;
+  eff_subst = Hashtbl.create 8;
+}
+
+(** Apply substitution to
type *)
+let rec apply_subst subst ty =
+  match ty with
+  | TUnit | TBool | TInt | TNat | TFloat | TChar | TString | TNever | TRigid _ -> ty
+  | TVar id ->
+    (* Follow chains of solved variables until an unsolved one or a
+       constructor is reached. *)
+    (match Hashtbl.find_opt subst.ty_subst id with
+     | Some t -> apply_subst subst t
+     | None -> ty)
+  | TApp (name, args) ->
+    TApp (name, List.map (apply_subst subst) args)
+  | TArrow (t1, t2, eff) ->
+    TArrow (apply_subst subst t1, apply_subst subst t2, apply_eff_subst subst eff)
+  | TForall (v, k, t) ->
+    TForall (v, k, apply_subst subst t)
+  | TTuple ts ->
+    TTuple (List.map (apply_subst subst) ts)
+  | TRecord row ->
+    TRecord (apply_row_subst subst row)
+  | TRef t -> TRef (apply_subst subst t)
+  | TMut t -> TMut (apply_subst subst t)
+  | TOwn t -> TOwn (apply_subst subst t)
+  | TRefined (t, r) -> TRefined (apply_subst subst t, r)
+  | TQuantified (q, t) -> TQuantified (q, apply_subst subst t)
+
+and apply_row_subst subst = function
+  | REmpty -> REmpty
+  | RVar id ->
+    (match Hashtbl.find_opt subst.row_subst id with
+     | Some r -> apply_row_subst subst r
+     | None -> RVar id)
+  | RExtend (l, t, r) ->
+    RExtend (l, apply_subst subst t, apply_row_subst subst r)
+
+and apply_eff_subst subst = function
+  | EEmpty -> EEmpty
+  | EVar id ->
+    (match Hashtbl.find_opt subst.eff_subst id with
+     | Some e -> apply_eff_subst subst e
+     | None -> EVar id)
+  | ECon (name, args) ->
+    ECon (name, List.map (apply_subst subst) args)
+  | EUnion (e1, e2) ->
+    EUnion (apply_eff_subst subst e1, apply_eff_subst subst e2)
+
+(** Extend substitution *)
+let extend_ty_subst subst id ty =
+  Hashtbl.replace subst.ty_subst id ty
+
+let extend_row_subst subst id row =
+  Hashtbl.replace subst.row_subst id row
+
+let extend_eff_subst subst id eff =
+  Hashtbl.replace subst.eff_subst id eff
+
+(** Free type variables in a type.  The result may contain duplicates;
+    variables inside effects are not visited. *)
+let rec free_tyvars ty =
+  match ty with
+  | TUnit | TBool | TInt | TNat | TFloat | TChar | TString | TNever | TRigid _ -> []
+  | TVar id -> [id]
+  | TApp (_, args) -> List.concat_map free_tyvars args
+  | TArrow (t1, t2, _) -> free_tyvars t1 @ free_tyvars t2
+  | TForall (_, _, t) -> free_tyvars t
+  | TTuple ts -> List.concat_map free_tyvars ts
+  | TRecord row -> free_tyvars_row row
+  | TRef t | TMut t | TOwn t -> free_tyvars t
+  | TRefined (t, _) -> free_tyvars t
+  | TQuantified (_, t) -> free_tyvars t
+
+and free_tyvars_row = function
+  | REmpty -> []
+  | RVar _ -> []
+  | RExtend (_, t, r) -> free_tyvars t @ free_tyvars_row r
+
+(** Occurs check - prevent infinite types *)
+let rec occurs id ty =
+  match ty with
+  | TVar id' -> id = id'
+  | TApp (_, args) -> List.exists (occurs id) args
+  | TArrow (t1, t2, _) -> occurs id t1 || occurs id t2
+  | TForall (_, _, t) -> occurs id t
+  | TTuple ts -> List.exists (occurs id) ts
+  | TRecord row -> occurs_row id row
+  | TRef t | TMut t | TOwn t -> occurs id t
+  | TRefined (t, _) -> occurs id t
+  | TQuantified (_, t) -> occurs id t
+  | _ -> false
+
+and occurs_row id = function
+  | REmpty | RVar _ -> false
+  | RExtend (_, t, r) -> occurs id t || occurs_row id r
+
+(** Pretty print type *)
+let rec pp_ty fmt ty =
+  match ty with
+  | TUnit -> Format.fprintf fmt "()"
+  | TBool -> Format.fprintf fmt "Bool"
+  | TInt -> Format.fprintf fmt "Int"
+  | TNat -> Format.fprintf fmt "Nat"
+  | TFloat -> Format.fprintf fmt "Float"
+  | TChar -> Format.fprintf fmt "Char"
+  | TString -> Format.fprintf fmt "String"
+  | TNever -> Format.fprintf fmt "Never"
+  | TVar id -> Format.fprintf fmt "?%d" id
+  | TRigid name -> Format.fprintf fmt "%s" name
+  | TApp (name, []) -> Format.fprintf fmt "%s" name
+  | TApp (name, args) ->
+    Format.fprintf fmt "%s[%a]" name
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") pp_ty) args
+  | TArrow (t1, t2, EEmpty) ->
+    Format.fprintf fmt "(%a -> %a)" pp_ty t1 pp_ty t2
+  | TArrow (t1, t2, eff) ->
+    Format.fprintf fmt "(%a -> %a / %a)" pp_ty t1 pp_ty t2 pp_eff eff
+  | TForall (v, _, t) ->
+    Format.fprintf fmt "(forall %s. %a)" v pp_ty t
+  | TTuple ts ->
+    Format.fprintf fmt "(%a)"
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") pp_ty) ts
+  | TRecord row ->
+    Format.fprintf fmt "{%a}" pp_row row
+  | TRef t -> Format.fprintf fmt "ref %a" pp_ty t
+  | TMut t -> Format.fprintf fmt "mut %a" pp_ty t
+  | TOwn t -> Format.fprintf fmt "own %a" pp_ty t
+  | TRefined (t, _) -> Format.fprintf fmt "%a where (...)" pp_ty t
+  | TQuantified (q, t) ->
+    let q_str = match q with QZero -> "0" | QOne -> "1" | QOmega -> "ω" in
+    Format.fprintf fmt "%s %a" q_str pp_ty t
+
+and pp_row fmt = function
+  | REmpty -> ()
+  | RVar id -> Format.fprintf fmt "..%d" id
+  | RExtend (l, t, REmpty) -> Format.fprintf fmt "%s: %a" l pp_ty t
+  | RExtend (l, t, r) -> Format.fprintf fmt "%s: %a, %a" l pp_ty t pp_row r
+
+and pp_eff fmt = function
+  | EEmpty -> Format.fprintf fmt "Pure"
+  | EVar id -> Format.fprintf fmt "?e%d" id
+  | ECon (name, []) -> Format.fprintf fmt "%s" name
+  | ECon (name, args) ->
+    Format.fprintf fmt "%s[%a]" name
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") pp_ty) args
+  | EUnion (e1, e2) -> Format.fprintf fmt "%a + %a" pp_eff e1 pp_eff e2
+
+(** Render a type to a string using [pp_ty]. *)
+let show_ty ty =
+  let buf = Buffer.create 64 in
+  let fmt = Format.formatter_of_buffer buf in
+  pp_ty fmt ty;
+  Format.pp_print_flush fmt ();
+  Buffer.contents buf
diff --git a/lib/value.ml b/lib/value.ml
new file mode 100644
index 0000000..9a824f4
--- /dev/null
+++ b/lib/value.ml
@@ -0,0 +1,170 @@
+(** Runtime values for AffineScript interpreter *)
+
+(** Runtime values *)
+type t =
+  | VUnit
+  | VBool of bool
+  | VInt of int
+  | VFloat of float
+  | VChar of char
+  | VString of string
+  | VTuple of t list
+  | VArray of t array
+  | VRecord of (string * t) list
+  | VVariant of string * string * t list  (** type, variant, args *)
+  | VClosure of {
+      params: Ast.param list;
+      body: Ast.expr;
+      env: env;
+    }
+  | VBuiltin of string * (t list -> t)
+  | VRef of t ref  (** Mutable reference *)
+
+(** Environment - variable bindings *)
+and env = {
+  bindings: (string, binding) Hashtbl.t;
+  parent: env option;
+}
+
+and binding = {
+  value: t;
+  mutable_: bool;
+  linear: bool;  (** Linear values can only be used once *)
+  mutable consumed: bool;  (** Track if linear value was consumed *)
+}
+
+(** Create a new empty environment *)
+let empty_env () = {
+  bindings = Hashtbl.create 32;
+  parent = None;
+}
+
+(** Create a child environment *)
+let child_env parent = {
+  bindings = Hashtbl.create 16;
+  parent = Some parent;
+}
+
+(** Bind a variable in the current scope; a new binding starts unconsumed *)
+let bind env name value ~mutable_ ~linear =
+  Hashtbl.replace env.bindings name { value; mutable_; linear; consumed = false }
+
+(** Look up a variable, innermost scope first *)
+let rec lookup env name =
+  match Hashtbl.find_opt env.bindings name with
+  | Some binding -> Some binding
+  | None -> Option.bind env.parent (fun p -> lookup p name)
+
+(** Update a mutable variable.  Returns [false] when the nearest binding of
+    [name] is immutable or when no binding exists. *)
+let rec update env name new_value =
+  match Hashtbl.find_opt env.bindings name with
+  | Some binding when binding.mutable_ ->
+    Hashtbl.replace env.bindings name { binding with value = new_value };
+    true
+  | Some _ -> false  (* Not mutable *)
+  | None ->
+    match env.parent with
+    | Some parent -> update parent name new_value
+    | None -> false
+
+(** Mark a linear binding as consumed and return its value.
+
+    The nearest binding of [name] is used.  Non-linear bindings are returned
+    unchanged; a linear binding yields [Error] on its second use, and an
+    unbound name yields [Error] too. *)
+let rec consume env name =
+  match Hashtbl.find_opt env.bindings name with
+  | Some binding when binding.linear ->
+    if binding.consumed then
+      Error (Printf.sprintf "Linear value '%s' already consumed" name)
+    else begin
+      binding.consumed <- true;
+      Ok binding.value
+    end
+  | Some binding -> Ok binding.value
+  | None ->
+    (* Not in this scope: apply the same rule to the enclosing scope. *)
+    match env.parent with
+    | Some parent -> consume parent name
+    | None -> Error (Printf.sprintf "Unbound variable: %s" name)
+
+(** Pretty print a value *)
+let rec pp fmt = function
+  | VUnit -> Format.fprintf fmt "()"
+  | VBool b -> Format.fprintf fmt "%b" b
+  | VInt i -> Format.fprintf fmt "%d" i
+  | VFloat f -> Format.fprintf fmt "%g" f
+  | VChar c -> Format.fprintf fmt "'%c'" c
+  | VString s -> Format.fprintf fmt "\"%s\"" (String.escaped s)
+  | VTuple vs ->
+    Format.fprintf fmt "(%a)"
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") pp) vs
+  | VArray vs ->
+    Format.fprintf fmt "[%a]"
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") pp)
+      (Array.to_list vs)
+  | VRecord fields ->
+    Format.fprintf fmt "{%a}"
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ")
+         (fun fmt (k, v) -> Format.fprintf fmt "%s: %a" k pp v))
+      fields
+  | VVariant (_, variant, []) ->
+    Format.fprintf fmt "%s" variant
+  | VVariant (_, variant, args) ->
+    Format.fprintf fmt "%s(%a)" variant
+      (Format.pp_print_list ~pp_sep:(fun fmt () -> Format.fprintf fmt ", ") pp)
+      args
+  (* NOTE(review): both format strings below were empty, which leaves the
+     builtin's name argument with no conversion to consume; the bracketed
+     text is a reconstruction - confirm the intended wording. *)
+  | VClosure _ -> Format.fprintf fmt "<closure>"
+  | VBuiltin (name, _) -> Format.fprintf fmt "<builtin:%s>" name
+  | VRef r -> Format.fprintf fmt "ref(%a)" pp !r
+
+(** Render a value to a string using [pp]. *)
+let show v =
+  let buf = Buffer.create 64 in
+  let fmt = Format.formatter_of_buffer buf in
+  pp fmt v;
+  Format.pp_print_flush fmt ();
+  Buffer.contents buf
+
+(** Check value equality.  Closures, builtins and references never compare
+    equal; record fields are compared in order. *)
+let rec equal v1 v2 =
+  match v1, v2 with
+  | VUnit, VUnit -> true
+  | VBool b1, VBool b2 -> b1 = b2
+  | VInt i1, VInt i2 -> i1 = i2
+  | VFloat f1, VFloat f2 -> f1 = f2
+  | VChar c1, VChar c2 -> c1 = c2
+  | VString s1, VString s2 -> s1 = s2
+  | VTuple vs1, VTuple vs2 ->
+    List.length vs1 = List.length vs2 &&
+    List.for_all2 equal vs1 vs2
+  | VArray a1, VArray a2 ->
+    Array.length a1 = Array.length a2 &&
+    Array.for_all2 equal a1 a2
+  | VRecord f1, VRecord f2 ->
+    List.length f1 = List.length f2 &&
+    List.for_all2 (fun (k1, v1) (k2, v2) -> k1 = k2 && equal v1 v2) f1 f2
+  | VVariant (t1, v1, a1), VVariant (t2, v2, a2) ->
+    t1 = t2 && v1 = v2 && List.length a1 = List.length a2 &&
+    List.for_all2 equal a1 a2
+  | _ -> false
+
+(** Coerce to bool for conditionals *)
+let to_bool = function
+  | VBool b -> Ok b
+  | v -> Error (Printf.sprintf "Expected Bool, got %s" (show v))
+
+(** Coerce to int *)
+let to_int = function
+  | VInt i -> Ok i
+  | v -> Error (Printf.sprintf "Expected Int, got %s" (show v))
+
+(** Coerce to float; integers are widened *)
+let to_float = function
+  | VFloat f -> Ok f
+  | VInt i -> Ok (Float.of_int i)
+  | v -> Error (Printf.sprintf "Expected Float, got %s" (show v))
diff --git a/lib/wasm_binary.ml b/lib/wasm_binary.ml
new file mode 100644
index 0000000..3131897
--- /dev/null
+++ b/lib/wasm_binary.ml
@@ -0,0 +1,677 @@
+(** WebAssembly binary format encoder *)
+
+open Codegen
+
+(** Buffer for building binary output *)
+type encoder = {
+  mutable buf: bytes;
+  mutable pos: int;
+}
+
+let create_encoder () = {
+  buf = Bytes.create 4096;
+  pos = 0;
+}
+
+(** Make room for [n] more bytes, growing to at least double the current
+    size when the buffer is too small. *)
+let ensure_capacity enc n =
+  let required = enc.pos + n in
+  if required > Bytes.length enc.buf then begin
+    let new_size = max required (Bytes.length enc.buf * 2) in
+    let new_buf = Bytes.create new_size in
+    Bytes.blit enc.buf 0 new_buf 0 enc.pos;
+    enc.buf <- new_buf
+  end
+
+let emit_byte enc b =
+  ensure_capacity enc 1;
+  Bytes.set_uint8 enc.buf enc.pos b;
+  enc.pos <- enc.pos + 1
+
+let emit_bytes enc bs =
+  let len = Bytes.length bs in
+  ensure_capacity enc len;
+  Bytes.blit bs 0 enc.buf enc.pos len;
+  enc.pos <- enc.pos + len
+
+(** Append the bytes of [s], copied straight into the buffer without an
+    intermediate [Bytes] value. *)
+let emit_string enc s =
+  let len = String.length s in
+  ensure_capacity enc len;
+  Bytes.blit_string s 0 enc.buf enc.pos len;
+  enc.pos <- enc.pos + len
+
+(** LEB128 encoding for unsigned integers *)
+let emit_u32_leb128 enc value =
+  let rec loop v =
+    let byte = v land 0x7f in
+    let v' = v lsr 7 in
+    if v' = 0 then
+      emit_byte enc byte
+    else begin
+      emit_byte enc (byte lor 0x80);
+      loop v'
+    end
+  in
+  loop value
+
+(** LEB128 encoding for signed integers *)
+let emit_s32_leb128 enc value =
+  let rec loop v =
+    let byte = v land 0x7f in
+    let v' = v asr 7 in
+    (* Stop once the remaining bits are pure sign extension of bit 6. *)
+    let done_ =
+      (v' = 0 && (byte land 0x40) = 0) ||
+      (v' = -1 && (byte land 0x40) <> 0)
+    in
+    if done_ then
+      emit_byte enc byte
+    else begin
+      emit_byte enc (byte lor 0x80);
+      loop v'
+    end
+  in
+  loop value
+
+(** LEB128 encoding for signed 64-bit integers *)
+let emit_s64_leb128 enc value =
+  let rec loop v =
+    let byte = Int64.to_int (Int64.logand v 0x7fL) in
+    let v' = Int64.shift_right v 7 in
+    let done_ =
+      (Int64.compare v' 0L = 0 && (byte land 0x40) = 0) ||
+      (Int64.compare v' (-1L) = 0 && (byte land 0x40) <> 0)
+    in
+    if done_ then
+      emit_byte enc byte
+    else begin
+      emit_byte enc (byte lor 0x80);
+      loop v'
+    end
+  in
+  loop value
+
+(** Emit a vector (length-prefixed sequence) *)
+let emit_vec enc items emit_item =
+  emit_u32_leb128 enc (List.length items);
+  List.iter (emit_item enc) items
+
+(** Emit a name (length-prefixed UTF-8 string) *)
+let emit_name enc name =
+  let len = String.length name in
+  emit_u32_leb128 enc len;
+  emit_string enc name
+
+(** WASM magic number and version *)
+let wasm_magic = "\x00\x61\x73\x6d"  (* \0asm *)
+let wasm_version = "\x01\x00\x00\x00"  (* version 1 *)
+
+(** Section IDs *)
+let section_custom = 0
+let section_type = 1
+let section_import = 2
+let section_function = 3
+let section_table = 4
+let section_memory = 5
+let section_global = 6
+let section_export = 7
+let section_start = 8
+let section_element = 9
+let section_code = 10
+let section_data = 11
+
+(** Value type encoding *)
+let encode_valtype = function
+  | I32 -> 0x7f
+  | I64 -> 0x7e
+  | F32 -> 0x7d
+  | F64 -> 0x7c
+  | Funcref -> 0x70
+  | Externref -> 0x6f
+
+(** Emit a value type *)
+let emit_valtype enc ty =
+  emit_byte enc (encode_valtype ty)
+
+(** Emit a result type (vector of value types) *)
+let emit_resulttype enc types =
+  emit_vec enc types (fun e t -> emit_valtype e t)
+
+(** Emit a function type *)
+let emit_functype enc ft =
+  emit_byte enc 0x60;  (* func type marker *)
+  emit_resulttype enc ft.ft_params;
+  emit_resulttype enc ft.ft_results
+
+(** Emit limits (for memory/table) *)
+let emit_limits enc min max_opt = + match max_opt with + | None -> + emit_byte enc 0x00; + emit_u32_leb128 enc min + | Some max -> + emit_byte enc 0x01; + emit_u32_leb128 enc min; + emit_u32_leb128 enc max + +(** Emit a memory type *) +let emit_memtype enc (min, max) = + emit_limits enc min max + +(** Emit an import *) +let emit_import enc import = + match import with + | ImportFunc (mod_name, func_name, _) -> + emit_name enc mod_name; + emit_name enc func_name; + emit_byte enc 0x00 (* func import *) + (* Type index would be emitted by caller *) + | ImportGlobal (mod_name, glob_name, ty, mutable_) -> + emit_name enc mod_name; + emit_name enc glob_name; + emit_byte enc 0x03; (* global import *) + emit_valtype enc ty; + emit_byte enc (if mutable_ then 0x01 else 0x00) + | ImportMemory (mod_name, mem_name, min, max) -> + emit_name enc mod_name; + emit_name enc mem_name; + emit_byte enc 0x02; (* memory import *) + emit_memtype enc (min, max) + | ImportTable (mod_name, tab_name, min, max) -> + emit_name enc mod_name; + emit_name enc tab_name; + emit_byte enc 0x01; (* table import *) + emit_byte enc 0x70; (* funcref *) + emit_limits enc min max + +(** Emit an export *) +let emit_export enc export = + match export with + | ExportFunc (name, idx) -> + emit_name enc name; + emit_byte enc 0x00; + emit_u32_leb128 enc idx + | ExportGlobal (name, idx) -> + emit_name enc name; + emit_byte enc 0x03; + emit_u32_leb128 enc idx + | ExportMemory name -> + emit_name enc name; + emit_byte enc 0x02; + emit_u32_leb128 enc 0 + | ExportTable name -> + emit_name enc name; + emit_byte enc 0x01; + emit_u32_leb128 enc 0 + +(** Instruction opcodes *) +let opcode_unreachable = 0x00 +let opcode_nop = 0x01 +let opcode_block = 0x02 +let opcode_loop = 0x03 +let opcode_if = 0x04 +let opcode_else = 0x05 +let opcode_end = 0x0b +let opcode_br = 0x0c +let opcode_br_if = 0x0d +let opcode_br_table = 0x0e +let opcode_return = 0x0f +let opcode_call = 0x10 +let opcode_call_indirect = 0x11 + +let opcode_drop = 
0x1a +let opcode_select = 0x1b + +let opcode_local_get = 0x20 +let opcode_local_set = 0x21 +let opcode_local_tee = 0x22 +let opcode_global_get = 0x23 +let opcode_global_set = 0x24 + +let opcode_i32_load = 0x28 +let opcode_i64_load = 0x29 +let opcode_f32_load = 0x2a +let opcode_f64_load = 0x2b +let opcode_i32_store = 0x36 +let opcode_i64_store = 0x37 +let opcode_f32_store = 0x38 +let opcode_f64_store = 0x39 +let opcode_memory_size = 0x3f +let opcode_memory_grow = 0x40 + +let opcode_i32_const = 0x41 +let opcode_i64_const = 0x42 +let opcode_f32_const = 0x43 +let opcode_f64_const = 0x44 + +let opcode_i32_eqz = 0x45 +let opcode_i32_eq = 0x46 +let opcode_i32_ne = 0x47 +let opcode_i32_lt_s = 0x48 +let opcode_i32_lt_u = 0x49 +let opcode_i32_gt_s = 0x4a +let opcode_i32_gt_u = 0x4b +let opcode_i32_le_s = 0x4c +let opcode_i32_le_u = 0x4d +let opcode_i32_ge_s = 0x4e +let opcode_i32_ge_u = 0x4f + +let opcode_i64_eqz = 0x50 +let opcode_i64_eq = 0x51 +let opcode_i64_ne = 0x52 +let opcode_i64_lt_s = 0x53 +let opcode_i64_lt_u = 0x54 +let opcode_i64_gt_s = 0x55 +let opcode_i64_gt_u = 0x56 +let opcode_i64_le_s = 0x57 +let opcode_i64_le_u = 0x58 +let opcode_i64_ge_s = 0x59 +let opcode_i64_ge_u = 0x5a + +let opcode_f32_eq = 0x5b +let opcode_f32_ne = 0x5c +let opcode_f32_lt = 0x5d +let opcode_f32_gt = 0x5e +let opcode_f32_le = 0x5f +let opcode_f32_ge = 0x60 + +let opcode_f64_eq = 0x61 +let opcode_f64_ne = 0x62 +let opcode_f64_lt = 0x63 +let opcode_f64_gt = 0x64 +let opcode_f64_le = 0x65 +let opcode_f64_ge = 0x66 + +let opcode_i32_add = 0x6a +let opcode_i32_sub = 0x6b +let opcode_i32_mul = 0x6c +let opcode_i32_div_s = 0x6d +let opcode_i32_div_u = 0x6e +let opcode_i32_rem_s = 0x6f +let opcode_i32_rem_u = 0x70 +let opcode_i32_and = 0x71 +let opcode_i32_or = 0x72 +let opcode_i32_xor = 0x73 +let opcode_i32_shl = 0x74 +let opcode_i32_shr_s = 0x75 +let opcode_i32_shr_u = 0x76 + +let opcode_i64_add = 0x7c +let opcode_i64_sub = 0x7d +let opcode_i64_mul = 0x7e +let opcode_i64_div_s = 0x7f +let 
opcode_i64_div_u = 0x80 +let opcode_i64_rem_s = 0x81 +let opcode_i64_rem_u = 0x82 +let opcode_i64_and = 0x83 +let opcode_i64_or = 0x84 +let opcode_i64_xor = 0x85 +let opcode_i64_shl = 0x86 +let opcode_i64_shr_s = 0x87 +let opcode_i64_shr_u = 0x88 + +let opcode_f32_add = 0x92 +let opcode_f32_sub = 0x93 +let opcode_f32_mul = 0x94 +let opcode_f32_div = 0x95 +let opcode_f32_neg = 0x8c +let opcode_f32_abs = 0x8b +let opcode_f32_sqrt = 0x91 +let opcode_f32_ceil = 0x8d +let opcode_f32_floor = 0x8e +let opcode_f32_trunc = 0x8f + +let opcode_f64_add = 0xa0 +let opcode_f64_sub = 0xa1 +let opcode_f64_mul = 0xa2 +let opcode_f64_div = 0xa3 +let opcode_f64_neg = 0x9a +let opcode_f64_abs = 0x99 +let opcode_f64_sqrt = 0x9f +let opcode_f64_ceil = 0x9b +let opcode_f64_floor = 0x9c +let opcode_f64_trunc = 0x9d + +(** Emit block type *) +let emit_blocktype enc = function + | None -> emit_byte enc 0x40 (* empty block type *) + | Some I32 -> emit_byte enc 0x7f + | Some I64 -> emit_byte enc 0x7e + | Some F32 -> emit_byte enc 0x7d + | Some F64 -> emit_byte enc 0x7c + | Some _ -> emit_byte enc 0x40 + +(** Emit f32 as IEEE 754 *) +let emit_f32 enc f = + let bits = Int32.bits_of_float f in + for i = 0 to 3 do + emit_byte enc (Int32.to_int (Int32.logand (Int32.shift_right_logical bits (i * 8)) 0xffl)) + done + +(** Emit f64 as IEEE 754 *) +let emit_f64 enc f = + let bits = Int64.bits_of_float f in + for i = 0 to 7 do + emit_byte enc (Int64.to_int (Int64.logand (Int64.shift_right_logical bits (i * 8)) 0xffL)) + done + +(** Emit an instruction *) +let rec emit_instr enc instr = + match instr with + | I32Const n -> emit_byte enc opcode_i32_const; emit_s32_leb128 enc n + | I64Const n -> emit_byte enc opcode_i64_const; emit_s64_leb128 enc n + | F32Const f -> emit_byte enc opcode_f32_const; emit_f32 enc f + | F64Const f -> emit_byte enc opcode_f64_const; emit_f64 enc f + + | LocalGet n -> emit_byte enc opcode_local_get; emit_u32_leb128 enc n + | LocalSet n -> emit_byte enc opcode_local_set; 
emit_u32_leb128 enc n + | LocalTee n -> emit_byte enc opcode_local_tee; emit_u32_leb128 enc n + | GlobalGet n -> emit_byte enc opcode_global_get; emit_u32_leb128 enc n + | GlobalSet n -> emit_byte enc opcode_global_set; emit_u32_leb128 enc n + + | I32Load (align, offset) -> + emit_byte enc opcode_i32_load; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | I32Store (align, offset) -> + emit_byte enc opcode_i32_store; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | I64Load (align, offset) -> + emit_byte enc opcode_i64_load; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | I64Store (align, offset) -> + emit_byte enc opcode_i64_store; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | F32Load (align, offset) -> + emit_byte enc opcode_f32_load; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | F32Store (align, offset) -> + emit_byte enc opcode_f32_store; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | F64Load (align, offset) -> + emit_byte enc opcode_f64_load; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + | F64Store (align, offset) -> + emit_byte enc opcode_f64_store; + emit_u32_leb128 enc align; + emit_u32_leb128 enc offset + + | MemorySize -> emit_byte enc opcode_memory_size; emit_byte enc 0x00 + | MemoryGrow -> emit_byte enc opcode_memory_grow; emit_byte enc 0x00 + + | I32Add -> emit_byte enc opcode_i32_add + | I32Sub -> emit_byte enc opcode_i32_sub + | I32Mul -> emit_byte enc opcode_i32_mul + | I32DivS -> emit_byte enc opcode_i32_div_s + | I32DivU -> emit_byte enc opcode_i32_div_u + | I32RemS -> emit_byte enc opcode_i32_rem_s + | I32RemU -> emit_byte enc opcode_i32_rem_u + | I32And -> emit_byte enc opcode_i32_and + | I32Or -> emit_byte enc opcode_i32_or + | I32Xor -> emit_byte enc opcode_i32_xor + | I32Shl -> emit_byte enc opcode_i32_shl + | I32ShrS -> emit_byte enc opcode_i32_shr_s + | I32ShrU -> emit_byte enc opcode_i32_shr_u + | I32Eqz -> emit_byte enc opcode_i32_eqz + | 
I32Eq -> emit_byte enc opcode_i32_eq + | I32Ne -> emit_byte enc opcode_i32_ne + | I32LtS -> emit_byte enc opcode_i32_lt_s + | I32LtU -> emit_byte enc opcode_i32_lt_u + | I32GtS -> emit_byte enc opcode_i32_gt_s + | I32GtU -> emit_byte enc opcode_i32_gt_u + | I32LeS -> emit_byte enc opcode_i32_le_s + | I32LeU -> emit_byte enc opcode_i32_le_u + | I32GeS -> emit_byte enc opcode_i32_ge_s + | I32GeU -> emit_byte enc opcode_i32_ge_u + + | I64Add -> emit_byte enc opcode_i64_add + | I64Sub -> emit_byte enc opcode_i64_sub + | I64Mul -> emit_byte enc opcode_i64_mul + | I64DivS -> emit_byte enc opcode_i64_div_s + | I64DivU -> emit_byte enc opcode_i64_div_u + | I64RemS -> emit_byte enc opcode_i64_rem_s + | I64RemU -> emit_byte enc opcode_i64_rem_u + | I64And -> emit_byte enc opcode_i64_and + | I64Or -> emit_byte enc opcode_i64_or + | I64Xor -> emit_byte enc opcode_i64_xor + | I64Shl -> emit_byte enc opcode_i64_shl + | I64ShrS -> emit_byte enc opcode_i64_shr_s + | I64ShrU -> emit_byte enc opcode_i64_shr_u + | I64Eqz -> emit_byte enc opcode_i64_eqz + | I64Eq -> emit_byte enc opcode_i64_eq + | I64Ne -> emit_byte enc opcode_i64_ne + | I64LtS -> emit_byte enc opcode_i64_lt_s + | I64LtU -> emit_byte enc opcode_i64_lt_u + | I64GtS -> emit_byte enc opcode_i64_gt_s + | I64GtU -> emit_byte enc opcode_i64_gt_u + | I64LeS -> emit_byte enc opcode_i64_le_s + | I64LeU -> emit_byte enc opcode_i64_le_u + | I64GeS -> emit_byte enc opcode_i64_ge_s + | I64GeU -> emit_byte enc opcode_i64_ge_u + + | F32Add -> emit_byte enc opcode_f32_add + | F32Sub -> emit_byte enc opcode_f32_sub + | F32Mul -> emit_byte enc opcode_f32_mul + | F32Div -> emit_byte enc opcode_f32_div + | F32Eq -> emit_byte enc opcode_f32_eq + | F32Ne -> emit_byte enc opcode_f32_ne + | F32Lt -> emit_byte enc opcode_f32_lt + | F32Gt -> emit_byte enc opcode_f32_gt + | F32Le -> emit_byte enc opcode_f32_le + | F32Ge -> emit_byte enc opcode_f32_ge + | F32Neg -> emit_byte enc opcode_f32_neg + | F32Abs -> emit_byte enc opcode_f32_abs + | 
F32Sqrt -> emit_byte enc opcode_f32_sqrt + | F32Ceil -> emit_byte enc opcode_f32_ceil + | F32Floor -> emit_byte enc opcode_f32_floor + | F32Trunc -> emit_byte enc opcode_f32_trunc + + | F64Add -> emit_byte enc opcode_f64_add + | F64Sub -> emit_byte enc opcode_f64_sub + | F64Mul -> emit_byte enc opcode_f64_mul + | F64Div -> emit_byte enc opcode_f64_div + | F64Eq -> emit_byte enc opcode_f64_eq + | F64Ne -> emit_byte enc opcode_f64_ne + | F64Lt -> emit_byte enc opcode_f64_lt + | F64Gt -> emit_byte enc opcode_f64_gt + | F64Le -> emit_byte enc opcode_f64_le + | F64Ge -> emit_byte enc opcode_f64_ge + | F64Neg -> emit_byte enc opcode_f64_neg + | F64Abs -> emit_byte enc opcode_f64_abs + | F64Sqrt -> emit_byte enc opcode_f64_sqrt + | F64Ceil -> emit_byte enc opcode_f64_ceil + | F64Floor -> emit_byte enc opcode_f64_floor + | F64Trunc -> emit_byte enc opcode_f64_trunc + + (* Conversions *) + | I32WrapI64 -> emit_byte enc 0xa7 + | I64ExtendI32S -> emit_byte enc 0xac + | I64ExtendI32U -> emit_byte enc 0xad + | F32ConvertI32S -> emit_byte enc 0xb2 + | F32ConvertI32U -> emit_byte enc 0xb3 + | F32ConvertI64S -> emit_byte enc 0xb4 + | F32ConvertI64U -> emit_byte enc 0xb5 + | F64ConvertI32S -> emit_byte enc 0xb7 + | F64ConvertI32U -> emit_byte enc 0xb8 + | F64ConvertI64S -> emit_byte enc 0xb9 + | F64ConvertI64U -> emit_byte enc 0xba + | I32TruncF32S -> emit_byte enc 0xa8 + | I32TruncF32U -> emit_byte enc 0xa9 + | I32TruncF64S -> emit_byte enc 0xaa + | I32TruncF64U -> emit_byte enc 0xab + | I64TruncF32S -> emit_byte enc 0xae + | I64TruncF32U -> emit_byte enc 0xaf + | I64TruncF64S -> emit_byte enc 0xb0 + | I64TruncF64U -> emit_byte enc 0xb1 + | F32DemoteF64 -> emit_byte enc 0xb6 + | F64PromoteF32 -> emit_byte enc 0xbb + | I32ReinterpretF32 -> emit_byte enc 0xbc + | I64ReinterpretF64 -> emit_byte enc 0xbd + | F32ReinterpretI32 -> emit_byte enc 0xbe + | F64ReinterpretI64 -> emit_byte enc 0xbf + + (* Control flow *) + | Unreachable -> emit_byte enc opcode_unreachable + | Nop -> emit_byte 
enc opcode_nop + + | Block (ty, instrs) -> + emit_byte enc opcode_block; + emit_blocktype enc ty; + List.iter (emit_instr enc) instrs; + emit_byte enc opcode_end + + | Loop (ty, instrs) -> + emit_byte enc opcode_loop; + emit_blocktype enc ty; + List.iter (emit_instr enc) instrs; + emit_byte enc opcode_end + + | If (ty, then_instrs, else_instrs) -> + emit_byte enc opcode_if; + emit_blocktype enc ty; + List.iter (emit_instr enc) then_instrs; + (match else_instrs with + | Some instrs -> + emit_byte enc opcode_else; + List.iter (emit_instr enc) instrs + | None -> ()); + emit_byte enc opcode_end + + | Br n -> emit_byte enc opcode_br; emit_u32_leb128 enc n + | BrIf n -> emit_byte enc opcode_br_if; emit_u32_leb128 enc n + | BrTable (labels, default) -> + emit_byte enc opcode_br_table; + emit_vec enc labels (fun e n -> emit_u32_leb128 e n); + emit_u32_leb128 enc default + | Return -> emit_byte enc opcode_return + | Call n -> emit_byte enc opcode_call; emit_u32_leb128 enc n + | CallIndirect n -> emit_byte enc opcode_call_indirect; emit_u32_leb128 enc n; emit_byte enc 0x00 + + | Drop -> emit_byte enc opcode_drop + | Select -> emit_byte enc opcode_select + +(** Emit a code section entry (function body) *) +let emit_code_entry enc func = + let body_enc = create_encoder () in + (* Locals: group by type *) + let local_groups = if func.fn_locals = [] then [] else + let rec group acc current_ty count = function + | [] -> List.rev ((count, current_ty) :: acc) + | ty :: rest when ty = current_ty -> + group acc current_ty (count + 1) rest + | ty :: rest -> + group ((count, current_ty) :: acc) ty 1 rest + in + match func.fn_locals with + | [] -> [] + | ty :: rest -> group [] ty 1 rest + in + emit_vec body_enc local_groups (fun e (count, ty) -> + emit_u32_leb128 e count; + emit_valtype e ty + ); + (* Instructions *) + List.iter (emit_instr body_enc) func.fn_body; + emit_byte body_enc opcode_end; + (* Emit size and body *) + emit_u32_leb128 enc body_enc.pos; + emit_bytes enc (Bytes.sub 
body_enc.buf 0 body_enc.pos) + +(** Emit a section *) +let emit_section enc section_id content_fn = + let content_enc = create_encoder () in + content_fn content_enc; + if content_enc.pos > 0 then begin + emit_byte enc section_id; + emit_u32_leb128 enc content_enc.pos; + emit_bytes enc (Bytes.sub content_enc.buf 0 content_enc.pos) + end + +(** Encode a WASM module to binary *) +let encode_module (module_ : wasm_module) = + let enc = create_encoder () in + + (* Magic and version *) + emit_string enc wasm_magic; + emit_string enc wasm_version; + + (* Type section *) + emit_section enc section_type (fun e -> + emit_vec e module_.mod_types emit_functype + ); + + (* Import section *) + if module_.mod_imports <> [] then + emit_section enc section_import (fun e -> + emit_u32_leb128 e (List.length module_.mod_imports); + List.iteri (fun i import -> + match import with + | ImportFunc (mod_name, func_name, _) -> + emit_name e mod_name; + emit_name e func_name; + emit_byte e 0x00; (* func import *) + emit_u32_leb128 e i (* type index - simplified *) + | _ -> emit_import e import + ) module_.mod_imports + ); + + (* Function section (type indices) *) + let num_imports = List.length (List.filter (function ImportFunc _ -> true | _ -> false) module_.mod_imports) in + emit_section enc section_function (fun e -> + emit_vec e module_.mod_funcs (fun e' _fn -> + emit_u32_leb128 e' 0 (* type index - simplified *) + ) + ); + + (* Memory section *) + (match module_.mod_memory with + | Some (min, max) -> + emit_section enc section_memory (fun e -> + emit_u32_leb128 e 1; (* 1 memory *) + emit_memtype e (min, max) + ) + | None -> ()); + + (* Export section *) + if module_.mod_exports <> [] then + emit_section enc section_export (fun e -> + emit_vec e module_.mod_exports (fun e' export -> + match export with + | ExportFunc (name, idx) -> + emit_name e' name; + emit_byte e' 0x00; + emit_u32_leb128 e' (num_imports + idx) + | ExportMemory name -> + emit_name e' name; + emit_byte e' 0x02; + 
emit_u32_leb128 e' 0 + | _ -> emit_export e' export + ) + ); + + (* Code section *) + emit_section enc section_code (fun e -> + emit_vec e module_.mod_funcs emit_code_entry + ); + + (* Data section *) + if module_.mod_data <> [] then + emit_section enc section_data (fun e -> + emit_vec e module_.mod_data (fun e' (offset, data) -> + emit_byte e' 0x00; (* active, memory 0 *) + emit_byte e' opcode_i32_const; + emit_s32_leb128 e' offset; + emit_byte e' opcode_end; + emit_u32_leb128 e' (String.length data); + emit_string e' data + ) + ); + + Bytes.sub enc.buf 0 enc.pos + +(** Write module to file *) +let write_to_file filename module_ = + let binary = encode_module module_ in + let oc = open_out_bin filename in + output_bytes oc binary; + close_out oc diff --git a/library/README.md b/library/README.md new file mode 100644 index 0000000..46846b3 --- /dev/null +++ b/library/README.md @@ -0,0 +1,87 @@ +# AffineScript Standard Library + +This directory contains the AffineScript standard library, organized into two main sections: + +## Common Library (`common/`) + +Language-agnostic utilities that could be used by any language: + +- **prelude.afs** - Core types (Option, Result, Ordering), fundamental traits (Eq, Ord, Clone, Display, Iterator) +- **collections.afs** - Data structures (Vec, HashMap, HashSet, LinkedList, BinaryHeap, Deque) +- **io.afs** - IO effect, file operations, streams, buffered readers/writers +- **async.afs** - Async/await primitives, futures, channels, synchronization +- **string.afs** - String manipulation, character utilities, StringBuilder +- **math.afs** - Mathematical constants and functions, statistics, random numbers +- **time.afs** - Duration, Instant, SystemTime, DateTime, timing utilities +- **sync.afs** - Atomic types, spinlocks, RwLock, barriers, latches + +## AffineScript Library (`affinescript/`) + +AffineScript-specific features leveraging the type system: + +- **linear.afs** - Linear/affine type utilities (LinearBox, Token, session 
types) +- **effects.afs** - Algebraic effect handlers (Reader, Writer, State, Exception, NonDet) +- **ownership.afs** - Ownership helpers (Box, Rc, Weak, Cow, RefCell) +- **refinements.afs** - Refinement types and dependent types (Positive, NonZero, Vec[N, T], Matrix) + +## Usage + +```afs +// Import everything from common +use Common::*; + +// Import AffineScript prelude +use AffineScript::prelude::*; + +// Import specific modules +use Common.Math::{sqrt, PI}; +use AffineScript.Linear::LinearBox; +use AffineScript.Refinements::Positive; +``` + +## Design Philosophy + +1. **Type Safety First** - Leverage AffineScript's type system to prevent errors at compile time +2. **Zero-Cost Abstractions** - Abstractions that compile away to efficient code +3. **Effect Tracking** - IO and other effects are explicit in function signatures +4. **Resource Safety** - Linear types ensure resources are properly managed + +## Examples + +### Option and Result +```afs +let x: Option[Int] = Some(42); +let y = x.map(|n| n * 2).unwrap_or(0); + +let result: Result[Int, String] = Ok(10); +let value = result?; // Propagates error +``` + +### Linear Resources +```afs +let handle = ResourceHandle::new(file, |f| f.close()); +handle.use_with(|f| f.write("hello")); +handle.close()?; // Must be called exactly once +``` + +### Effects +```afs +fn computation() -> Int / State[Int], Except[String] { + let current = State::get(); + if current < 0 { + Except::throw("negative state"); + } + State::put(current + 1); + current +} +``` + +### Refinement Types +```afs +fn safe_div(a: Int, b: NonZero) -> Int { + a / b // Statically guaranteed b != 0 +} + +let v: Vec[3, Int] = Vec::from_array([1, 2, 3]).unwrap(); +let first = v.get::<0>(); // Compile-time bounds check +``` diff --git a/library/affinescript/effects.afs b/library/affinescript/effects.afs new file mode 100644 index 0000000..9f22523 --- /dev/null +++ b/library/affinescript/effects.afs @@ -0,0 +1,320 @@ +// AffineScript Specific Library - Effects 
+// Utilities for working with algebraic effects + +module AffineScript.Effects; + +use Common.Prelude::*; + +// ============================================================================ +// Effect Row Types +// ============================================================================ + +/// Empty effect row (pure computation) +pub type Pure = (); + +/// Single effect +pub type Single[E] = E; + +/// Effect union (E1 | E2) +pub type Union[E1, E2] = (E1, E2); + +// ============================================================================ +// Effect Utilities +// ============================================================================ + +/// Lift a pure computation into an effectful context +pub fn lift_pure[T, E](value: T) -> T / E { + value +} + +/// Embed a smaller effect set into a larger one +pub fn embed[T, E1, E2](computation: fn() -> T / E1) -> T / E2 +where + E1: SubEffect[E2], +{ + computation() +} + +/// Check if a computation is pure (no effects) +pub fn is_pure[T](f: fn() -> T) -> Bool { + true // Type system ensures this +} + +// ============================================================================ +// Effect Handlers +// ============================================================================ + +/// Handler for a single effect +pub struct Handler[E, R] { + /// Handle return value + return_handler: fn(R) -> R, + /// Handle effect operations + operations: E, +} + +impl[E, R] Handler[E, R] { + pub fn new(return_handler: fn(R) -> R, operations: E) -> Handler[E, R] { + Handler { return_handler, operations } + } +} + +/// Run computation with handler +pub fn handle[T, E, R]( + handler: Handler[E, R], + computation: fn() -> T / E +) -> R { + // Would invoke the handler infrastructure + panic!("implemented by compiler") +} + +// ============================================================================ +// Common Effect Patterns +// ============================================================================ + +/// Reader effect - 
access to read-only environment +pub effect Reader[R] { + /// Get the environment + fn ask() -> R; + + /// Run with locally modified environment + fn local[A](f: fn(R) -> R, computation: fn() -> A / Reader[R]) -> A; +} + +/// Reader handler +pub fn run_reader[R, T](env: R, computation: fn() -> T / Reader[R]) -> T { + handle( + Handler::new( + |x| x, + ReaderOps { env } + ), + computation + ) +} + +struct ReaderOps[R] { + env: R, +} + +/// Writer effect - accumulate output +pub effect Writer[W] { + /// Append to output + fn tell(w: W) -> (); + + /// Get output so far and reset + fn listen() -> W; +} + +/// Writer handler (requires W: Monoid) +pub fn run_writer[W: Monoid, T](computation: fn() -> T / Writer[W]) -> (T, W) { + // Would accumulate writes + panic!("implemented by compiler") +} + +/// State effect - mutable state +pub effect State[S] { + /// Get current state + fn get() -> S; + + /// Set new state + fn put(s: S) -> (); + + /// Modify state + fn modify(f: fn(S) -> S) -> (); +} + +/// State handler +pub fn run_state[S, T](initial: S, computation: fn() -> T / State[S]) -> (T, S) { + // Would thread state through + panic!("implemented by compiler") +} + +/// Execute state, returning only the result +pub fn eval_state[S, T](initial: S, computation: fn() -> T / State[S]) -> T { + let (result, _) = run_state(initial, computation); + result +} + +/// Execute state, returning only the final state +pub fn exec_state[S, T](initial: S, computation: fn() -> T / State[S]) -> S { + let (_, state) = run_state(initial, computation); + state +} + +/// Exception effect +pub effect Except[E] { + /// Throw an exception + fn throw(e: E) -> Never; + + /// Catch and handle exception + fn catch[A](computation: fn() -> A / Except[E], handler: fn(E) -> A) -> A; +} + +/// Exception handler +pub fn run_except[E, T](computation: fn() -> T / Except[E]) -> Result[T, E] { + // Would catch exceptions + panic!("implemented by compiler") +} + +/// Non-determinism effect +pub effect NonDet { 
+ /// Choose between two branches + fn choose() -> Bool; + + /// Fail computation (backtrack) + fn fail() -> Never; +} + +/// Run non-deterministically, collecting all results +pub fn run_list[T](computation: fn() -> T / NonDet) -> Array[T] { + // Would collect all branches + panic!("implemented by compiler") +} + +/// Run non-deterministically, returning first success +pub fn run_maybe[T](computation: fn() -> T / NonDet) -> Option[T] { + // Would return first non-failing branch + panic!("implemented by compiler") +} + +// ============================================================================ +// Effect Composition +// ============================================================================ + +/// Compose two handlers +pub fn compose_handlers[E1, E2, R]( + h1: Handler[E1, R], + h2: Handler[E2, R] +) -> Handler[Union[E1, E2], R] { + // Would compose effect handlers + panic!("implemented by compiler") +} + +/// Run multiple effects with their handlers +pub fn run_effects[T, E1, E2, R]( + h1: Handler[E1, R], + h2: Handler[E2, R], + computation: fn() -> T / Union[E1, E2] +) -> R { + handle(compose_handlers(h1, h2), computation) +} + +// ============================================================================ +// Continuation Utilities +// ============================================================================ + +/// Delimited continuation +pub struct Cont[R, A] { + run: fn(fn(A) -> R) -> R, +} + +impl[R, A] Cont[R, A] { + pub fn new(f: fn(fn(A) -> R) -> R) -> Cont[R, A] { + Cont { run: f } + } + + pub fn run(self: Self, k: fn(A) -> R) -> R { + (self.run)(k) + } + + pub fn map[B](self: Self, f: fn(A) -> B) -> Cont[R, B] { + Cont::new(|k| self.run(|a| k(f(a)))) + } + + pub fn flat_map[B](self: Self, f: fn(A) -> Cont[R, B]) -> Cont[R, B] { + Cont::new(|k| self.run(|a| f(a).run(k))) + } +} + +/// Call with current continuation +pub fn call_cc[R, A](f: fn(fn(A) -> Cont[R, A]) -> Cont[R, A]) -> Cont[R, A] { + Cont::new(|k| { + f(|a| Cont::new(|_| 
k(a))).run(k) + }) +} + +/// Reset/shift delimited control +pub fn reset[A](computation: Cont[A, A]) -> A { + computation.run(|x| x) +} + +pub fn shift[R, A](f: fn(fn(A) -> R) -> Cont[R, R]) -> Cont[R, A] { + Cont::new(|k| reset(f(k))) +} + +// ============================================================================ +// Effect Polymorphism +// ============================================================================ + +/// Marker trait for effects that can be handled +pub trait Handleable { + type Handler; + type Result; + + fn default_handler() -> Self::Handler; +} + +/// Marker trait for effect subtyping +pub trait SubEffect[Super] {} + +// Every effect is a subeffect of itself +impl[E] SubEffect[E] for E {} + +// Pure is a subeffect of everything +impl[E] SubEffect[E] for Pure {} + +/// Run a computation that may have effects, handling them with defaults +pub fn run_with_defaults[T, E: Handleable](computation: fn() -> T / E) -> E::Result { + handle(E::default_handler(), computation) +} + +// ============================================================================ +// Effect Inference Helpers +// ============================================================================ + +/// Annotate a function with its effects (for documentation) +pub fn with_effects[T, E](f: fn() -> T / E) -> fn() -> T / E { + f +} + +/// Mask effects (pretend they're not there - unsafe!) 
+/// Only use when you know the effects are handled externally +pub unsafe fn mask_effects[T, E](computation: fn() -> T / E) -> fn() -> T { + || panic!("effects must be handled") +} + +/// Assert a computation is pure +pub fn assert_pure[T](computation: fn() -> T) -> T { + computation() +} + +// ============================================================================ +// Monoid Trait for Writer +// ============================================================================ + +/// Monoid - type with identity and associative operation +pub trait Monoid { + fn empty() -> Self; + fn append(self: Self, other: Self) -> Self; +} + +impl Monoid for String { + fn empty() -> String { "" } + fn append(self: String, other: String) -> String { self + other } +} + +impl[T] Monoid for Array[T] { + fn empty() -> Array[T] { [] } + fn append(self: Array[T], other: Array[T]) -> Array[T] { + let mut result = self; + for item in other { + result.push(item); + } + result + } +} + +impl Monoid for Int { + fn empty() -> Int { 0 } + fn append(self: Int, other: Int) -> Int { self + other } +} + diff --git a/library/affinescript/linear.afs b/library/affinescript/linear.afs new file mode 100644 index 0000000..e471229 --- /dev/null +++ b/library/affinescript/linear.afs @@ -0,0 +1,425 @@ +// AffineScript Specific Library - Linear Types +// Utilities for working with linear and affine types + +module AffineScript.Linear; + +use Common.Prelude::*; + +// ============================================================================ +// Quantity Annotations +// ============================================================================ + +/// Quantity type - represents usage annotation +/// 0 = erased (compile-time only) +/// 1 = linear (exactly once) +/// ω = unrestricted (any number of times) +pub enum Quantity { + /// Erased - exists only at compile time + Zero, + /// Linear - must be used exactly once + One, + /// Unrestricted - can be used any number of times + Many, +} + +impl Quantity { 
+ /// Multiply quantities (semiring multiplication) + /// 0 * q = 0 + /// 1 * q = q + /// ω * q = ω (if q ≠ 0) + pub fn mul(self: Self, other: Self) -> Self { + match (self, other) { + (Zero, _) | (_, Zero) => Zero, + (One, q) | (q, One) => q, + (Many, Many) => Many, + } + } + + /// Add quantities (semiring addition) + /// 0 + q = q + /// 1 + 1 = ω + /// ω + q = ω + pub fn add(self: Self, other: Self) -> Self { + match (self, other) { + (Zero, q) | (q, Zero) => q, + (One, One) => Many, + (Many, _) | (_, Many) => Many, + } + } + + /// Check if q1 <= q2 in the quantity lattice + pub fn subquantity(self: Self, other: Self) -> Bool { + match (self, other) { + (Zero, _) => true, + (One, One) | (One, Many) => true, + (Many, Many) => true, + _ => false, + } + } +} + +// ============================================================================ +// Linear Box - Enforce linear usage +// ============================================================================ + +/// A box that enforces linear usage of its contents +/// Once opened, it cannot be used again +#[linear] +pub struct LinearBox[T] { + value: T, +} + +impl[T] LinearBox[T] { + /// Create a new linear box + pub fn new(value: T) -> LinearBox[T] { + LinearBox { value } + } + + /// Consume the box and get the value (can only be called once) + pub fn open(self: Self) -> T { + self.value + } + + /// Transform the contents + pub fn map[U](self: Self, f: fn(T) -> U) -> LinearBox[U] { + LinearBox::new(f(self.value)) + } + + /// Chain operations + pub fn and_then[U](self: Self, f: fn(T) -> LinearBox[U]) -> LinearBox[U] { + f(self.value) + } +} + +// ============================================================================ +// Affine Box - At most once +// ============================================================================ + +/// A box that enforces affine usage (at most once, can be dropped) +#[affine] +pub struct AffineBox[T] { + value: T, +} + +impl[T] AffineBox[T] { + pub fn new(value: T) -> AffineBox[T] { 
+ AffineBox { value } + } + + /// Consume and get value + pub fn take(self: Self) -> T { + self.value + } + + /// Try to take, returning None if already consumed + pub fn try_take(self: Self) -> Option[T] { + Some(self.value) + } + + /// Transform contents + pub fn map[U](self: Self, f: fn(T) -> U) -> AffineBox[U] { + AffineBox::new(f(self.value)) + } +} + +// ============================================================================ +// Unique Reference +// ============================================================================ + +/// A unique (non-aliasable) reference +/// Guarantees exclusive access to the referenced value +#[linear] +pub struct Unique[T] { + ptr: own T, +} + +impl[T] Unique[T] { + /// Create a unique reference from owned value + pub fn new(value: T) -> Unique[T] { + Unique { ptr: value } + } + + /// Get immutable reference + pub fn get(self: ref Self) -> ref T { + &self.ptr + } + + /// Get mutable reference + pub fn get_mut(self: mut Self) -> mut T { + &mut self.ptr + } + + /// Consume and return owned value + pub fn into_inner(self: Self) -> T { + self.ptr + } + + /// Replace the value, returning old one + pub fn replace(self: mut Self, value: T) -> T { + let old = self.ptr; + self.ptr = value; + old + } + + /// Swap with another Unique + pub fn swap(self: mut Self, other: mut Unique[T]) { + let tmp = self.ptr; + self.ptr = other.ptr; + other.ptr = tmp; + } +} + +// ============================================================================ +// Resource Handle +// ============================================================================ + +/// A handle to a resource that must be explicitly closed +/// Prevents resource leaks at compile time +#[linear] +pub struct ResourceHandle[T, E] { + resource: T, + close: fn(T) -> Result[(), E], +} + +impl[T, E] ResourceHandle[T, E] { + /// Create a new resource handle + pub fn new(resource: T, close: fn(T) -> Result[(), E]) -> ResourceHandle[T, E] { + ResourceHandle { resource, close } + } + + 
/// Use the resource with a callback + pub fn use_with[R](self: ref Self, f: fn(ref T) -> R) -> R { + f(&self.resource) + } + + /// Use mutably with a callback + pub fn use_mut_with[R](self: mut Self, f: fn(mut T) -> R) -> R { + f(&mut self.resource) + } + + /// Close the resource (must be called exactly once) + pub fn close(self: Self) -> Result[(), E] { + (self.close)(self.resource) + } +} + +/// Use a resource and automatically close it +pub fn with_resource[T, E, R]( + resource: T, + close: fn(T) -> Result[(), E], + f: fn(ref T) -> R +) -> Result[R, E] { + let handle = ResourceHandle::new(resource, close); + let result = handle.use_with(f); + handle.close()?; + Ok(result) +} + +// ============================================================================ +// Linear Token +// ============================================================================ + +/// A token that represents permission to perform an action once +/// Cannot be duplicated, must be consumed +#[linear] +pub struct Token[phantom Action] { + _phantom: (), +} + +impl[Action] Token[Action] { + /// Create a new token (internal use) + pub fn mint() -> Token[Action] { + Token { _phantom: () } + } + + /// Consume the token and get proof of consumption + pub fn consume(self: Self) -> Consumed[Action] { + Consumed { _phantom: () } + } +} + +/// Proof that a token was consumed +pub struct Consumed[phantom Action] { + _phantom: (), +} + +// ============================================================================ +// Linear State Machine +// ============================================================================ + +/// State machine with linear state transitions +/// Ensures proper state progression at compile time +pub trait LinearState { + /// The next valid state(s) + type Next; +} + +/// State machine wrapper +#[linear] +pub struct StateMachine[S: LinearState] { + state: S, +} + +impl[S: LinearState] StateMachine[S] { + pub fn new(initial: S) -> StateMachine[S] { + StateMachine { state: 
initial } + } + + /// Transition to next state (consumes current state) + pub fn transition[N](self: Self, next: N) -> StateMachine[N] + where + N: LinearState, + S::Next: Contains[N], + { + StateMachine { state: next } + } + + /// Access current state + pub fn state(self: ref Self) -> ref S { + &self.state + } + + /// Complete the state machine (must be in final state) + pub fn complete(self: Self) -> S + where + S::Next: IsEmpty, + { + self.state + } +} + +// ============================================================================ +// Linear Channel +// ============================================================================ + +/// One-shot channel - can send exactly one value +pub struct OneshotSender[T] { + #[linear] + inner: LinearBox[fn(T) -> ()], +} + +pub struct OneshotReceiver[T] { + #[linear] + inner: LinearBox[fn() -> T], +} + +/// Create a one-shot channel +pub fn oneshot[T]() -> (OneshotSender[T], OneshotReceiver[T]) { + // Implementation would set up single-use communication + let sender = OneshotSender { inner: LinearBox::new(|_| ()) }; + let receiver = OneshotReceiver { inner: LinearBox::new(|| panic!("not implemented")) }; + (sender, receiver) +} + +impl[T] OneshotSender[T] { + /// Send a value (consumes the sender) + pub fn send(self: Self, value: T) { + let f = self.inner.open(); + f(value); + } +} + +impl[T] OneshotReceiver[T] { + /// Receive the value (consumes the receiver) + pub fn recv(self: Self) -> T { + let f = self.inner.open(); + f() + } +} + +// ============================================================================ +// Session Types (Simplified) +// ============================================================================ + +/// A session-typed channel endpoint +/// Ensures protocol compliance at compile time +pub struct Session[Protocol] { + // Implementation hidden + _phantom: Protocol, +} + +/// Send action in a session +pub struct Send[T, Cont] { + _phantom: (T, Cont), +} + +/// Receive action in a session +pub 
struct Recv[T, Cont] { + _phantom: (T, Cont), +} + +/// End of session +pub struct End {} + +impl[T, Cont] Session[Send[T, Cont]] { + /// Send a value and continue with the rest of the protocol + pub fn send(self: Self, value: T) -> Session[Cont] { + Session { _phantom: panic!("not implemented") } + } +} + +impl[T, Cont] Session[Recv[T, Cont]] { + /// Receive a value and continue with the rest of the protocol + pub fn recv(self: Self) -> (T, Session[Cont]) { + panic!("not implemented") + } +} + +impl Session[End] { + /// Close the session (must be at End state) + pub fn close(self: Self) { + // Session complete + } +} + +// ============================================================================ +// Borrowing Utilities +// ============================================================================ + +/// Reborrow a reference with a shorter lifetime +pub fn reborrow[T](r: ref T) -> ref T { + r +} + +/// Reborrow mutably +pub fn reborrow_mut[T](r: mut T) -> mut T { + r +} + +/// Convert owned to reference for duration of scope +pub fn as_ref[T, R](owned: T, f: fn(ref T) -> R) -> R { + f(&owned) +} + +/// Convert owned to mutable reference for duration of scope +pub fn as_mut[T, R](mut owned: T, f: fn(mut T) -> R) -> R { + f(&mut owned) +} + +// ============================================================================ +// Phantom Data +// ============================================================================ + +/// Zero-sized type for carrying type information without runtime cost +pub struct PhantomData[T] { + // Zero-sized, exists only for type checking +} + +impl[T] PhantomData[T] { + pub fn new() -> PhantomData[T] { + PhantomData {} + } +} + +impl[T] Clone for PhantomData[T] { + fn clone(self: ref Self) -> Self { + PhantomData {} + } +} + +impl[T] Default for PhantomData[T] { + fn default() -> Self { + PhantomData {} + } +} + diff --git a/library/affinescript/mod.afs b/library/affinescript/mod.afs new file mode 100644 index 0000000..5015cb6 --- 
/dev/null +++ b/library/affinescript/mod.afs @@ -0,0 +1,34 @@ +// AffineScript Specific Library +// Re-exports all AffineScript-specific modules + +module AffineScript; + +pub use AffineScript.Linear::*; +pub use AffineScript.Effects::*; +pub use AffineScript.Ownership::*; +pub use AffineScript.Refinements::*; + +/// Library version +pub const VERSION: String = "0.1.0"; + +/// Library name +pub const NAME: String = "affinescript"; + +// ============================================================================ +// Prelude for AffineScript +// ============================================================================ + +/// AffineScript prelude - commonly used types and traits +pub mod prelude { + // From Common library + pub use Common.Prelude::*; + pub use Common.Collections::{Vec, HashMap, HashSet}; + pub use Common.IO::{IO, IOError, File, Read, Write}; + pub use Common.Async::{Future, Async, spawn, block_on}; + + // AffineScript-specific + pub use AffineScript.Linear::{LinearBox, AffineBox, Unique, Token}; + pub use AffineScript.Effects::{Reader, Writer, State, Except, NonDet}; + pub use AffineScript.Ownership::{Own, Box, Rc, Weak, Cow}; + pub use AffineScript.Refinements::{Positive, NonNegative, NonZero, NonEmpty}; +} diff --git a/library/affinescript/ownership.afs b/library/affinescript/ownership.afs new file mode 100644 index 0000000..2d8565a --- /dev/null +++ b/library/affinescript/ownership.afs @@ -0,0 +1,476 @@ +// AffineScript Specific Library - Ownership +// Utilities for working with ownership and borrowing + +module AffineScript.Ownership; + +use Common.Prelude::*; + +// ============================================================================ +// Ownership Types +// ============================================================================ + +/// Owned value - full ownership with move semantics +pub struct Own[T] { + value: T, +} + +impl[T] Own[T] { + /// Create owned value + pub fn new(value: T) -> Own[T] { + Own { value } + } + + /// Move 
out of Own wrapper + pub fn into_inner(self: Self) -> T { + self.value + } + + /// Borrow immutably + pub fn borrow(self: ref Self) -> ref T { + &self.value + } + + /// Borrow mutably + pub fn borrow_mut(self: mut Self) -> mut T { + &mut self.value + } + + /// Map over owned value + pub fn map[U](self: Self, f: fn(T) -> U) -> Own[U] { + Own::new(f(self.value)) + } +} + +// ============================================================================ +// Reference Types +// ============================================================================ + +/// Immutable reference wrapper +pub struct Ref[T] { + ptr: ref T, +} + +impl[T] Ref[T] { + /// Create from reference + pub fn new(ptr: ref T) -> Ref[T] { + Ref { ptr } + } + + /// Get the reference + pub fn get(self: Self) -> ref T { + self.ptr + } + + /// Clone the reference (references are Copy) + pub fn clone(self: Self) -> Ref[T] { + Ref { ptr: self.ptr } + } + + /// Map over reference + pub fn map[U](self: Self, f: fn(ref T) -> ref U) -> Ref[U] { + Ref::new(f(self.ptr)) + } +} + +/// Mutable reference wrapper +pub struct RefMut[T] { + ptr: mut T, +} + +impl[T] RefMut[T] { + /// Create from mutable reference + pub fn new(ptr: mut T) -> RefMut[T] { + RefMut { ptr } + } + + /// Get immutable reference + pub fn get(self: ref Self) -> ref T { + self.ptr + } + + /// Get mutable reference + pub fn get_mut(self: mut Self) -> mut T { + self.ptr + } + + /// Reborrow as immutable + pub fn reborrow(self: ref Self) -> Ref[T] { + Ref::new(self.ptr) + } + + /// Map over mutable reference + pub fn map_mut[U](self: Self, f: fn(mut T) -> mut U) -> RefMut[U] { + RefMut::new(f(self.ptr)) + } +} + +// ============================================================================ +// Smart Pointers +// ============================================================================ + +/// Box - heap-allocated value with unique ownership +pub struct Box[T] { + ptr: own T, +} + +impl[T] Box[T] { + /// Allocate value on heap + pub fn 
new(value: T) -> Box[T] { + Box { ptr: value } + } + + /// Get reference to contents + pub fn get(self: ref Self) -> ref T { + &self.ptr + } + + /// Get mutable reference + pub fn get_mut(self: mut Self) -> mut T { + &mut self.ptr + } + + /// Move value out of box + pub fn into_inner(self: Self) -> T { + self.ptr + } + + /// Leak the box, returning a static reference + pub fn leak(self: Self) -> ref T { + // Would leak memory intentionally + &self.ptr + } +} + +impl[T: Clone] Clone for Box[T] { + fn clone(self: ref Self) -> Self { + Box::new(self.ptr.clone()) + } +} + +/// Reference-counted pointer +pub struct Rc[T] { + ptr: ref T, + count: ref Int, +} + +impl[T] Rc[T] { + /// Create new Rc + pub fn new(value: T) -> Rc[T] { + // Would allocate with refcount + Rc { ptr: &value, count: &1 } + } + + /// Get reference to contents + pub fn get(self: ref Self) -> ref T { + self.ptr + } + + /// Get strong count + pub fn strong_count(self: ref Self) -> Int { + *self.count + } + + /// Check if this is the only reference + pub fn is_unique(self: ref Self) -> Bool { + *self.count == 1 + } + + /// Try to get mutable access (only if unique) + pub fn get_mut(self: mut Self) -> Option[mut T] { + if self.is_unique() { + // Would return mutable reference + None + } else { + None + } + } + + /// Make a mutable copy (clone-on-write) + pub fn make_mut(self: mut Self) -> mut T + where + T: Clone, + { + if !self.is_unique() { + // Clone the data + *self = Rc::new(self.ptr.clone()); + } + // Now we're unique, safe to return mut ref + self.get_mut().unwrap() + } + + /// Try to unwrap if unique + pub fn try_unwrap(self: Self) -> Result[T, Rc[T]] { + if self.is_unique() { + Ok(*self.ptr) // Would move out + } else { + Err(self) + } + } +} + +impl[T] Clone for Rc[T] { + fn clone(self: ref Self) -> Self { + // Increment refcount + Rc { ptr: self.ptr, count: self.count } + } +} + +/// Weak reference (doesn't prevent deallocation) +pub struct Weak[T] { + ptr: Option[ref T], +} + +impl[T] 
Weak[T] { + /// Create weak reference from Rc + pub fn new(rc: ref Rc[T]) -> Weak[T] { + Weak { ptr: Some(rc.ptr) } + } + + /// Create empty weak reference + pub fn empty() -> Weak[T] { + Weak { ptr: None } + } + + /// Try to upgrade to Rc + pub fn upgrade(self: ref Self) -> Option[Rc[T]] { + // Would check if still alive and increment count + None + } + + /// Check if still alive + pub fn is_alive(self: ref Self) -> Bool { + self.ptr.is_some() + } +} + +// ============================================================================ +// Borrowed Wrapper +// ============================================================================ + +/// Wrapper indicating borrowed data +pub struct Borrowed[T] { + data: ref T, + _lifetime: (), +} + +impl[T] Borrowed[T] { + pub fn new(data: ref T) -> Borrowed[T] { + Borrowed { data, _lifetime: () } + } + + pub fn get(self: Self) -> ref T { + self.data + } +} + +/// Convert owned to borrowed for function call +pub fn borrow_for[T, R](owned: ref T, f: fn(Borrowed[T]) -> R) -> R { + f(Borrowed::new(owned)) +} + +// ============================================================================ +// Cow (Clone-on-Write) +// ============================================================================ + +/// Clone-on-Write: either borrowed or owned +pub enum Cow[T] { + /// Borrowed data + Borrowed(ref T), + /// Owned data + Owned(T), +} + +impl[T: Clone] Cow[T] { + /// Get reference to data + pub fn get(self: ref Self) -> ref T { + match self { + Borrowed(r) => r, + Owned(v) => v, + } + } + + /// Get mutable reference (clones if borrowed) + pub fn to_mut(self: mut Self) -> mut T { + match self { + Borrowed(r) => { + *self = Owned(r.clone()); + match self { + Owned(v) => v, + _ => unreachable!(), + } + }, + Owned(v) => v, + } + } + + /// Convert to owned (clones if borrowed) + pub fn into_owned(self: Self) -> T { + match self { + Borrowed(r) => r.clone(), + Owned(v) => v, + } + } + + /// Check if borrowed + pub fn is_borrowed(self: ref Self) 
-> Bool { + match self { + Borrowed(_) => true, + Owned(_) => false, + } + } + + /// Check if owned + pub fn is_owned(self: ref Self) -> Bool { + match self { + Borrowed(_) => false, + Owned(_) => true, + } + } +} + +// ============================================================================ +// ManuallyDrop +// ============================================================================ + +/// Wrapper that prevents automatic drop +pub struct ManuallyDrop[T] { + value: T, +} + +impl[T] ManuallyDrop[T] { + /// Wrap value + pub fn new(value: T) -> ManuallyDrop[T] { + ManuallyDrop { value } + } + + /// Get reference + pub fn get(self: ref Self) -> ref T { + &self.value + } + + /// Get mutable reference + pub fn get_mut(self: mut Self) -> mut T { + &mut self.value + } + + /// Take the value out + pub fn into_inner(self: Self) -> T { + self.value + } + + /// Manually drop the value (unsafe - must only call once) + pub unsafe fn drop(self: mut Self) { + // Would call drop + } +} + +// ============================================================================ +// MaybeUninit +// ============================================================================ + +/// Wrapper for possibly uninitialized memory +pub struct MaybeUninit[T] { + // May contain uninitialized bytes + data: T, + initialized: Bool, +} + +impl[T] MaybeUninit[T] { + /// Create uninitialized + pub fn uninit() -> MaybeUninit[T] { + // Would allocate without initializing + MaybeUninit { data: panic!("uninit"), initialized: false } + } + + /// Create initialized + pub fn new(value: T) -> MaybeUninit[T] { + MaybeUninit { data: value, initialized: true } + } + + /// Write value (initializing) + pub fn write(self: mut Self, value: T) { + self.data = value; + self.initialized = true; + } + + /// Assume initialized and take value (unsafe) + pub unsafe fn assume_init(self: Self) -> T { + assert!(self.initialized, "value not initialized"); + self.data + } + + /// Check if initialized + pub fn 
is_initialized(self: ref Self) -> Bool { + self.initialized + } + + /// Get reference (unsafe - must be initialized) + pub unsafe fn get(self: ref Self) -> ref T { + assert!(self.initialized, "value not initialized"); + &self.data + } +} + +// ============================================================================ +// Take +// ============================================================================ + +/// Take trait - take a value, leaving default +pub trait Take: Default { + fn take(self: mut Self) -> Self { + let value = *self; + *self = Self::default(); + value + } +} + +impl Take for Int { + fn take(self: mut Self) -> Self { + let v = *self; + *self = 0; + v + } +} + +impl Take for String { + fn take(self: mut Self) -> Self { + let v = *self; + *self = ""; + v + } +} + +impl[T] Take for Option[T] { + fn take(self: mut Self) -> Self { + let v = *self; + *self = None; + v + } +} + +impl[T] Take for Vec[T] { + fn take(self: mut Self) -> Self { + let v = *self; + *self = Vec::new(); + v + } +} + +// ============================================================================ +// Replace +// ============================================================================ + +/// Replace a value, returning the old one +pub fn replace[T](dest: mut T, value: T) -> T { + let old = *dest; + *dest = value; + old +} + +/// Swap two values +pub fn swap[T](a: mut T, b: mut T) { + let tmp = *a; + *a = *b; + *b = tmp; +} + diff --git a/library/affinescript/refinements.afs b/library/affinescript/refinements.afs new file mode 100644 index 0000000..d4063bc --- /dev/null +++ b/library/affinescript/refinements.afs @@ -0,0 +1,447 @@ +// AffineScript Specific Library - Refinements +// Utilities for refinement types and dependent types + +module AffineScript.Refinements; + +use Common.Prelude::*; + +// ============================================================================ +// Refinement Type Constructors +// 
============================================================================ + +/// Positive integer: { n: Int | n > 0 } +pub type Positive = refined Int where self > 0; + +/// Non-negative integer: { n: Int | n >= 0 } +pub type NonNegative = refined Int where self >= 0; + +/// Natural number (alias for NonNegative) +pub type Nat = NonNegative; + +/// Negative integer: { n: Int | n < 0 } +pub type Negative = refined Int where self < 0; + +/// Non-zero integer: { n: Int | n != 0 } +pub type NonZero = refined Int where self != 0; + +/// Bounded integer: { n: Int | lo <= n < hi } +pub type Bounded[const LO: Int, const HI: Int] = refined Int where self >= LO && self < HI; + +/// Percentage (0-100): { n: Int | 0 <= n <= 100 } +pub type Percentage = refined Int where self >= 0 && self <= 100; + +/// Unit interval: { x: Float | 0.0 <= x <= 1.0 } +pub type UnitFloat = refined Float where self >= 0.0 && self <= 1.0; + +/// Positive float: { x: Float | x > 0.0 } +pub type PositiveFloat = refined Float where self > 0.0; + +/// Non-empty string: { s: String | len(s) > 0 } +pub type NonEmptyString = refined String where len(self) > 0; + +/// Non-empty array: { arr: Array[T] | len(arr) > 0 } +pub type NonEmpty[T] = refined Array[T] where len(self) > 0; + +// ============================================================================ +// Smart Constructors +// ============================================================================ + +/// Try to create a Positive from Int +pub fn positive(n: Int) -> Option[Positive] { + if n > 0 { + Some(n as Positive) + } else { + None + } +} + +/// Try to create a NonNegative from Int +pub fn non_negative(n: Int) -> Option[NonNegative] { + if n >= 0 { + Some(n as NonNegative) + } else { + None + } +} + +/// Try to create a NonZero from Int +pub fn non_zero(n: Int) -> Option[NonZero] { + if n != 0 { + Some(n as NonZero) + } else { + None + } +} + +/// Try to create a bounded integer +pub fn bounded[const LO: Int, const HI: Int](n: Int) -> 
Option[Bounded[LO, HI]] { + if n >= LO && n < HI { + Some(n as Bounded[LO, HI]) + } else { + None + } +} + +/// Try to create a unit float +pub fn unit_float(x: Float) -> Option[UnitFloat] { + if x >= 0.0 && x <= 1.0 { + Some(x as UnitFloat) + } else { + None + } +} + +/// Try to create a non-empty string +pub fn non_empty_string(s: String) -> Option[NonEmptyString] { + if s.len() > 0 { + Some(s as NonEmptyString) + } else { + None + } +} + +/// Try to create a non-empty array +pub fn non_empty[T](arr: Array[T]) -> Option[NonEmpty[T]] { + if len(arr) > 0 { + Some(arr as NonEmpty[T]) + } else { + None + } +} + +// ============================================================================ +// Refinement Operations +// ============================================================================ + +/// Safe division (divisor must be non-zero) +pub fn safe_div(a: Int, b: NonZero) -> Int { + a / (b as Int) +} + +/// Safe float division +pub fn safe_divf(a: Float, b: NonZero) -> Float { + a / (b as Int as Float) +} + +/// Safe modulo +pub fn safe_mod(a: Int, b: NonZero) -> Int { + a % (b as Int) +} + +/// Safe array head (array must be non-empty) +pub fn head[T](arr: NonEmpty[T]) -> T { + arr[0] +} + +/// Safe array tail (array must be non-empty) +pub fn tail[T](arr: NonEmpty[T]) -> Array[T] { + arr[1..] 
+} + +/// Safe array last (array must be non-empty) +pub fn last[T](arr: NonEmpty[T]) -> T { + arr[len(arr) - 1] +} + +/// Safe array init (all but last) +pub fn init[T](arr: NonEmpty[T]) -> Array[T] { + arr[..len(arr) - 1] +} + +/// Absolute value with proof it's non-negative +pub fn abs_nat(n: Int) -> NonNegative { + if n >= 0 { n as NonNegative } + else { (-n) as NonNegative } +} + +/// Square (always non-negative) +pub fn square(n: Int) -> NonNegative { + (n * n) as NonNegative +} + +// ============================================================================ +// Length-Indexed Vectors +// ============================================================================ + +/// Vector with statically known length +pub struct Vec[const N: Nat, T] { + data: Array[T], +} + +impl[const N: Nat, T] Vec[N, T] { + /// Create from array (checked) + pub fn from_array(arr: Array[T]) -> Option[Vec[N, T]] { + if len(arr) == N { + Some(Vec { data: arr }) + } else { + None + } + } + + /// Get element at compile-time-checked index + pub fn get[const I: Nat](self: ref Self) -> T + where + I < N, // Compile-time bound check + { + self.data[I] + } + + /// Get length (known at compile time) + pub fn len(self: ref Self) -> Nat { + N + } + + /// Map over elements + pub fn map[U](self: Self, f: fn(T) -> U) -> Vec[N, U] { + Vec { data: self.data.map(f) } + } + + /// Zip two vectors of same length + pub fn zip[U](self: Self, other: Vec[N, U]) -> Vec[N, (T, U)] { + let mut result = []; + for i in 0..N { + result.push((self.data[i], other.data[i])); + } + Vec { data: result } + } + + /// Convert to regular array + pub fn to_array(self: Self) -> Array[T] { + self.data + } +} + +/// Empty vector +pub fn empty[T]() -> Vec[0, T] { + Vec { data: [] } +} + +/// Singleton vector +pub fn singleton[T](value: T) -> Vec[1, T] { + Vec { data: [value] } +} + +/// Replicate value n times +pub fn replicate[const N: Nat, T: Clone](value: T) -> Vec[N, T] { + let mut data = []; + for _ in 0..N { + 
data.push(value.clone()); + } + Vec { data } +} + +/// Append two vectors +pub fn append[const N: Nat, const M: Nat, T]( + a: Vec[N, T], + b: Vec[M, T] +) -> Vec[N + M, T] { + let mut data = a.data; + for item in b.data { + data.push(item); + } + Vec { data } +} + +/// Head of non-empty vector +pub fn vec_head[const N: Nat, T](v: Vec[N + 1, T]) -> T { + v.data[0] +} + +/// Tail of non-empty vector +pub fn vec_tail[const N: Nat, T](v: Vec[N + 1, T]) -> Vec[N, T] { + Vec { data: v.data[1..] } +} + +// ============================================================================ +// Matrix Types +// ============================================================================ + +/// Matrix with known dimensions +pub struct Matrix[const M: Nat, const N: Nat, T] { + data: Array[Array[T]], // Row-major +} + +impl[const M: Nat, const N: Nat, T] Matrix[M, N, T] { + /// Get element at checked indices + pub fn get[const I: Nat, const J: Nat](self: ref Self) -> T + where + I < M, + J < N, + { + self.data[I][J] + } + + /// Get row count + pub fn rows(self: ref Self) -> Nat { M } + + /// Get column count + pub fn cols(self: ref Self) -> Nat { N } + + /// Transpose matrix + pub fn transpose(self: Self) -> Matrix[N, M, T] { + let mut data = []; + for j in 0..N { + let mut row = []; + for i in 0..M { + row.push(self.data[i][j]); + } + data.push(row); + } + Matrix { data } + } +} + +/// Create zero matrix +pub fn zeros[const M: Nat, const N: Nat]() -> Matrix[M, N, Int] { + let mut data = []; + for _ in 0..M { + let mut row = []; + for _ in 0..N { + row.push(0); + } + data.push(row); + } + Matrix { data } +} + +/// Create identity matrix +pub fn identity[const N: Nat]() -> Matrix[N, N, Int] { + let mut data = []; + for i in 0..N { + let mut row = []; + for j in 0..N { + row.push(if i == j { 1 } else { 0 }); + } + data.push(row); + } + Matrix { data } +} + +/// Matrix multiplication +pub fn matmul[const M: Nat, const N: Nat, const P: Nat]( + a: Matrix[M, N, Int], + b: Matrix[N, P, Int] 
+) -> Matrix[M, P, Int] { + let mut data = []; + for i in 0..M { + let mut row = []; + for j in 0..P { + let mut sum = 0; + for k in 0..N { + sum += a.data[i][k] * b.data[k][j]; + } + row.push(sum); + } + data.push(row); + } + Matrix { data } +} + +// ============================================================================ +// Proof Objects +// ============================================================================ + +/// Proof that a < b +pub struct LessThan[const A: Int, const B: Int] { + _phantom: (), +} + +/// Proof that a <= b +pub struct LessEq[const A: Int, const B: Int] { + _phantom: (), +} + +/// Proof that a == b +pub struct Equal[const A: Int, const B: Int] { + _phantom: (), +} + +/// Create proof of less than (checked at compile time) +pub fn prove_lt[const A: Int, const B: Int]() -> LessThan[A, B] +where + A < B, +{ + LessThan { _phantom: () } +} + +/// Create proof of less equal +pub fn prove_le[const A: Int, const B: Int]() -> LessEq[A, B] +where + A <= B, +{ + LessEq { _phantom: () } +} + +/// Create proof of equality +pub fn prove_eq[const A: Int, const B: Int]() -> Equal[A, B] +where + A == B, +{ + Equal { _phantom: () } +} + +/// Transitivity of less than +pub fn trans_lt[const A: Int, const B: Int, const C: Int]( + _p1: LessThan[A, B], + _p2: LessThan[B, C] +) -> LessThan[A, C] { + LessThan { _phantom: () } +} + +// ============================================================================ +// Bounded Index +// ============================================================================ + +/// Index type that's always valid for an array +pub struct Index[const N: Nat] { + value: Nat, +} + +impl[const N: Nat] Index[N] { + /// Try to create an index + pub fn new(i: Int) -> Option[Index[N]] { + if i >= 0 && i < N { + Some(Index { value: i as Nat }) + } else { + None + } + } + + /// Get the value + pub fn get(self: Self) -> Nat { + self.value + } + + /// Index into array (always safe) + pub fn index[T](self: Self, arr: ref Vec[N, T]) -> 
ref T { + &arr.data[self.value] + } +} + +/// Iterate over all valid indices +pub fn indices[const N: Nat]() -> Array[Index[N]] { + let mut result = []; + for i in 0..N { + result.push(Index { value: i }); + } + result +} + +// ============================================================================ +// Assertion Helpers +// ============================================================================ + +/// Assert a condition at runtime, returning proof type +pub fn assert_positive(n: Int) -> Result[Positive, String] { + if n > 0 { + Ok(n as Positive) + } else { + Err(format("expected positive, got {}", [n.to_string()])) + } +} + +/// Unreachable with proof +pub fn absurd[T](proof: LessThan[1, 0]) -> T { + unreachable!("1 < 0 is impossible") +} + diff --git a/library/common/async.afs b/library/common/async.afs new file mode 100644 index 0000000..6cc78e1 --- /dev/null +++ b/library/common/async.afs @@ -0,0 +1,457 @@ +// AffineScript Common Library - Async +// Asynchronous programming primitives + +module Common.Async; + +use Common.Prelude::*; +use Common.IO::IOError; + +// ============================================================================ +// Future Types +// ============================================================================ + +/// A future representing an asynchronous computation +pub enum Future[T] { + /// Computation is still pending + Pending, + /// Computation completed successfully + Ready(T), + /// Computation failed + Failed(AsyncError), +} + +/// Async operation errors +pub enum AsyncError { + /// Task was cancelled + Cancelled, + /// Task timed out + Timeout, + /// Task panicked + Panicked(String), + /// IO error during async operation + IOError(IOError), + /// Join error when awaiting task + JoinError(String), + /// Channel was closed + ChannelClosed, + /// Send failed + SendError(String), + /// Receive failed + RecvError(String), +} + +impl AsyncError { + pub fn message(self: ref Self) -> String { + match self { + Cancelled => 
"task cancelled", + Timeout => "operation timed out", + Panicked(msg) => "task panicked: " + msg, + IOError(e) => "io error: " + e.message(), + JoinError(msg) => "join error: " + msg, + ChannelClosed => "channel closed", + SendError(msg) => "send error: " + msg, + RecvError(msg) => "receive error: " + msg, + } + } +} + +// ============================================================================ +// Async Effect +// ============================================================================ + +/// Async effect for asynchronous operations +pub effect Async { + /// Await a future + fn await[T](future: Future[T]) -> T; + + /// Spawn a new task + fn spawn[T](f: fn() -> T / Async) -> Future[T]; + + /// Yield control to the scheduler + fn yield_() -> (); + + /// Sleep for specified milliseconds + fn sleep(ms: Int) -> (); + + /// Get current task ID + fn task_id() -> Int; +} + +// ============================================================================ +// Future Implementation +// ============================================================================ + +impl[T] Future[T] { + /// Check if future is ready + pub fn is_ready(self: ref Self) -> Bool { + match self { + Ready(_) => true, + _ => false, + } + } + + /// Check if future is pending + pub fn is_pending(self: ref Self) -> Bool { + match self { + Pending => true, + _ => false, + } + } + + /// Check if future failed + pub fn is_failed(self: ref Self) -> Bool { + match self { + Failed(_) => true, + _ => false, + } + } + + /// Get value if ready + pub fn try_get(self: ref Self) -> Option[ref T] { + match self { + Ready(v) => Some(v), + _ => None, + } + } + + /// Get error if failed + pub fn error(self: ref Self) -> Option[ref AsyncError] { + match self { + Failed(e) => Some(e), + _ => None, + } + } + + /// Map the result value + pub fn map[U](self: Self, f: fn(T) -> U) -> Future[U] { + match self { + Pending => Pending, + Ready(v) => Ready(f(v)), + Failed(e) => Failed(e), + } + } + + /// Chain futures + pub 
fn and_then[U](self: Self, f: fn(T) -> Future[U]) -> Future[U] { + match self { + Pending => Pending, + Ready(v) => f(v), + Failed(e) => Failed(e), + } + } + + /// Provide a fallback on failure + pub fn or_else(self: Self, f: fn(AsyncError) -> Future[T]) -> Future[T] { + match self { + Failed(e) => f(e), + other => other, + } + } + + /// Unwrap with default on failure + pub fn unwrap_or(self: Self, default: T) -> T { + match self { + Ready(v) => v, + _ => default, + } + } +} + +// ============================================================================ +// Async Combinators +// ============================================================================ + +/// Join two futures, returning when both complete +pub fn join[A, B](a: Future[A], b: Future[B]) -> Future[(A, B)] / Async { + let a_val = await(a); + let b_val = await(b); + Ready((a_val, b_val)) +} + +/// Join three futures +pub fn join3[A, B, C](a: Future[A], b: Future[B], c: Future[C]) -> Future[(A, B, C)] / Async { + let a_val = await(a); + let b_val = await(b); + let c_val = await(c); + Ready((a_val, b_val, c_val)) +} + +/// Select the first future to complete +pub fn select[T](a: Future[T], b: Future[T]) -> Future[T] / Async { + // Implementation would poll both and return first ready + await(a) // Simplified +} + +/// Race multiple futures, returning the first to complete +pub fn race[T](futures: Array[Future[T]]) -> Future[T] / Async { + // Implementation would poll all and return first ready + if len(futures) > 0 { + await(futures[0]) + } else { + Failed(AsyncError::JoinError("empty race")) + } +} + +/// Wait for all futures to complete +pub fn join_all[T](futures: Array[Future[T]]) -> Future[Array[T]] / Async { + let mut results = []; + for future in futures { + results.push(await(future)); + } + Ready(results) +} + +/// Try to join all, collecting results and errors +pub fn try_join_all[T](futures: Array[Future[T]]) -> Future[Array[Result[T, AsyncError]]] / Async { + let mut results = []; + for 
future in futures { + match future { + Ready(v) => results.push(Ok(v)), + Failed(e) => results.push(Err(e)), + Pending => { + let v = await(future); + results.push(Ok(v)); + } + } + } + Ready(results) +} + +// ============================================================================ +// Timeout +// ============================================================================ + +/// Run a future with a timeout +pub fn timeout[T](ms: Int, future: Future[T]) -> Future[Result[T, AsyncError]] / Async { + // Implementation would race with sleep + match await(future) { + v => Ready(Ok(v)), + } +} + +/// Run a future with a timeout, returning None on timeout +pub fn timeout_opt[T](ms: Int, future: Future[T]) -> Future[Option[T]] / Async { + match timeout(ms, future) { + Ready(Ok(v)) => Ready(Some(v)), + Ready(Err(_)) => Ready(None), + Failed(e) => Failed(e), + Pending => Pending, + } +} + +// ============================================================================ +// Channels +// ============================================================================ + +/// A channel sender +pub struct Sender[T] { + // Implementation details hidden + id: Int, +} + +/// A channel receiver +pub struct Receiver[T] { + // Implementation details hidden + id: Int, +} + +/// Create a bounded channel +pub fn channel[T](capacity: Int) -> (Sender[T], Receiver[T]) { + let id = 0; // Would be unique ID + (Sender { id }, Receiver { id }) +} + +/// Create an unbounded channel +pub fn unbounded[T]() -> (Sender[T], Receiver[T]) { + channel(0) // 0 means unbounded +} + +impl[T] Sender[T] { + /// Send a value + pub fn send(self: ref Self, value: T) -> Result[(), AsyncError] / Async { + // Would actually send + Ok(()) + } + + /// Try to send without blocking + pub fn try_send(self: ref Self, value: T) -> Result[(), AsyncError] { + // Would try to send immediately + Ok(()) + } + + /// Check if receiver is still alive + pub fn is_closed(self: ref Self) -> Bool { + false + } +} + +impl[T] 
Receiver[T] { + /// Receive a value + pub fn recv(self: ref Self) -> Result[T, AsyncError] / Async { + // Would actually receive + Err(AsyncError::ChannelClosed) + } + + /// Try to receive without blocking + pub fn try_recv(self: ref Self) -> Option[T] { + None + } + + /// Create an iterator over received values + pub fn iter(self: Self) -> RecvIter[T] { + RecvIter { receiver: self } + } +} + +/// Iterator over channel values +pub struct RecvIter[T] { + receiver: Receiver[T], +} + +impl[T] Iterator for RecvIter[T] { + type Item = T; + + fn next(self: mut Self) -> Option[T] / Async { + match self.receiver.recv() { + Ok(v) => Some(v), + Err(_) => None, + } + } +} + +// ============================================================================ +// Mutex and Synchronization +// ============================================================================ + +/// Async mutex for protecting shared state +pub struct Mutex[T] { + // Implementation details hidden + value: T, + locked: Bool, +} + +impl[T] Mutex[T] { + /// Create a new mutex + pub fn new(value: T) -> Mutex[T] { + Mutex { value, locked: false } + } + + /// Lock the mutex + pub fn lock(self: ref Self) -> MutexGuard[T] / Async { + // Would wait until lock available + MutexGuard { mutex: self } + } + + /// Try to lock without blocking + pub fn try_lock(self: ref Self) -> Option[MutexGuard[T]] { + if !self.locked { + Some(MutexGuard { mutex: self }) + } else { + None + } + } +} + +/// Guard that releases mutex on drop +pub struct MutexGuard[T] { + mutex: ref Mutex[T], +} + +impl[T] MutexGuard[T] { + pub fn get(self: ref Self) -> ref T { + &self.mutex.value + } + + pub fn get_mut(self: mut Self) -> mut T { + &mut self.mutex.value + } +} + +/// Semaphore for limiting concurrent access +pub struct Semaphore { + permits: Int, + max_permits: Int, +} + +impl Semaphore { + pub fn new(permits: Int) -> Semaphore { + Semaphore { permits, max_permits: permits } + } + + pub fn acquire(self: mut Self) -> SemaphorePermit / 
Async { + // Would wait for permit + self.permits -= 1; + SemaphorePermit { semaphore: self } + } + + pub fn try_acquire(self: mut Self) -> Option[SemaphorePermit] { + if self.permits > 0 { + self.permits -= 1; + Some(SemaphorePermit { semaphore: self }) + } else { + None + } + } + + pub fn available_permits(self: ref Self) -> Int { + self.permits + } +} + +pub struct SemaphorePermit { + semaphore: mut Semaphore, +} + +// ============================================================================ +// Task Spawning Helpers +// ============================================================================ + +/// Spawn a task and return its handle +pub fn spawn[T](f: fn() -> T / Async) -> TaskHandle[T] / Async { + let future = Async::spawn(f); + TaskHandle { future, cancelled: false } +} + +/// Handle to a spawned task +pub struct TaskHandle[T] { + future: Future[T], + cancelled: Bool, +} + +impl[T] TaskHandle[T] { + /// Wait for task to complete + pub fn join(self: Self) -> Result[T, AsyncError] / Async { + if self.cancelled { + Err(AsyncError::Cancelled) + } else { + Ok(await(self.future)) + } + } + + /// Cancel the task + pub fn cancel(self: mut Self) { + self.cancelled = true; + } + + /// Check if task is finished + pub fn is_finished(self: ref Self) -> Bool { + self.future.is_ready() || self.future.is_failed() + } +} + +/// Spawn a detached task (result is ignored) +pub fn spawn_detached(f: fn() -> () / Async) / Async { + let _ = Async::spawn(f); +} + +// ============================================================================ +// Async Block Helper +// ============================================================================ + +/// Run an async block to completion (blocking) +pub fn block_on[T](f: fn() -> T / Async) -> T { + // Would run event loop until complete + // This is the entry point from sync to async + f() +} + diff --git a/library/common/collections.afs b/library/common/collections.afs new file mode 100644 index 0000000..fc58d69 --- /dev/null 
+++ b/library/common/collections.afs @@ -0,0 +1,531 @@ +// AffineScript Common Library - Collections +// Generic collection types and algorithms + +module Common.Collections; + +use Common.Prelude::*; + +// ============================================================================ +// Vec - Dynamic Array +// ============================================================================ + +/// Dynamic array with ownership semantics +pub struct Vec[T] { + data: own Array[T], + len: Int, + cap: Int, +} + +impl[T] Vec[T] { + /// Create empty vector + pub fn new() -> Vec[T] { + Vec { data: [], len: 0, cap: 0 } + } + + /// Create vector with capacity + pub fn with_capacity(cap: Int) -> Vec[T] { + Vec { data: [], len: 0, cap: cap } + } + + /// Create from array + pub fn from_array(arr: Array[T]) -> Vec[T] { + let n = len(arr); + Vec { data: arr, len: n, cap: n } + } + + /// Get length + pub fn len(self: ref Self) -> Int { + self.len + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.len == 0 + } + + /// Get capacity + pub fn capacity(self: ref Self) -> Int { + self.cap + } + + /// Push element to back + pub fn push(self: mut Self, value: T) { + if self.len >= self.cap { + self.grow(); + } + self.data[self.len] = value; + self.len += 1; + } + + /// Pop element from back + pub fn pop(self: mut Self) -> Option[T] { + if self.len == 0 { + None + } else { + self.len -= 1; + Some(self.data[self.len]) + } + } + + /// Get element by index + pub fn get(self: ref Self, index: Int) -> Option[ref T] { + if index >= 0 && index < self.len { + Some(&self.data[index]) + } else { + None + } + } + + /// Get mutable element by index + pub fn get_mut(self: mut Self, index: Int) -> Option[mut T] { + if index >= 0 && index < self.len { + Some(&mut self.data[index]) + } else { + None + } + } + + /// Get first element + pub fn first(self: ref Self) -> Option[ref T] { + self.get(0) + } + + /// Get last element + pub fn last(self: ref Self) -> Option[ref T] { + if 
self.len > 0 { + self.get(self.len - 1) + } else { + None + } + } + + /// Clear all elements + pub fn clear(self: mut Self) { + self.len = 0; + } + + /// Insert at index + pub fn insert(self: mut Self, index: Int, value: T) { + assert(index >= 0 && index <= self.len); + if self.len >= self.cap { + self.grow(); + } + // Shift elements right + let mut i = self.len; + while i > index { + self.data[i] = self.data[i - 1]; + i -= 1; + } + self.data[index] = value; + self.len += 1; + } + + /// Remove at index + pub fn remove(self: mut Self, index: Int) -> T { + assert(index >= 0 && index < self.len); + let value = self.data[index]; + // Shift elements left + let mut i = index; + while i < self.len - 1 { + self.data[i] = self.data[i + 1]; + i += 1; + } + self.len -= 1; + value + } + + /// Grow capacity + fn grow(self: mut Self) { + let new_cap = if self.cap == 0 { 4 } else { self.cap * 2 }; + // Would reallocate data array + self.cap = new_cap; + } + + /// Convert to iterator + pub fn iter(self: ref Self) -> VecIter[T] { + VecIter { vec: self, index: 0 } + } +} + +/// Vec iterator +pub struct VecIter[T] { + vec: ref Vec[T], + index: Int, +} + +impl[T] Iterator for VecIter[T] { + type Item = ref T; + + fn next(self: mut Self) -> Option[ref T] { + if self.index < self.vec.len { + let item = &self.vec.data[self.index]; + self.index += 1; + Some(item) + } else { + None + } + } +} + +// ============================================================================ +// HashMap +// ============================================================================ + +/// Hash map with separate chaining +pub struct HashMap[K: Hash + Eq, V] { + buckets: Array[Option[HashMapEntry[K, V]]], + len: Int, + cap: Int, +} + +struct HashMapEntry[K, V] { + key: K, + value: V, + next: Option[own HashMapEntry[K, V]], +} + +impl[K: Hash + Eq, V] HashMap[K, V] { + /// Create empty hash map + pub fn new() -> HashMap[K, V] { + HashMap::with_capacity(16) + } + + /// Create with capacity + pub fn 
with_capacity(cap: Int) -> HashMap[K, V] { + let buckets = []; // Initialize with None + HashMap { buckets: buckets, len: 0, cap: cap } + } + + /// Get length + pub fn len(self: ref Self) -> Int { + self.len + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.len == 0 + } + + /// Insert key-value pair + pub fn insert(self: mut Self, key: K, value: V) -> Option[V] { + let hash = key.hash(); + let index = hash % self.cap; + // Would implement bucket insertion + self.len += 1; + None + } + + /// Get value by key + pub fn get(self: ref Self, key: ref K) -> Option[ref V] { + let hash = key.hash(); + let index = hash % self.cap; + // Would search bucket chain + None + } + + /// Remove key + pub fn remove(self: mut Self, key: ref K) -> Option[V] { + let hash = key.hash(); + let index = hash % self.cap; + // Would remove from bucket chain + None + } + + /// Check if key exists + pub fn contains_key(self: ref Self, key: ref K) -> Bool { + self.get(key).is_some() + } + + /// Clear all entries + pub fn clear(self: mut Self) { + self.len = 0; + } +} + +// ============================================================================ +// HashSet +// ============================================================================ + +/// Hash set +pub struct HashSet[T: Hash + Eq] { + map: HashMap[T, ()], +} + +impl[T: Hash + Eq] HashSet[T] { + /// Create empty set + pub fn new() -> HashSet[T] { + HashSet { map: HashMap::new() } + } + + /// Get length + pub fn len(self: ref Self) -> Int { + self.map.len() + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.map.is_empty() + } + + /// Insert element + pub fn insert(self: mut Self, value: T) -> Bool { + self.map.insert(value, ()).is_none() + } + + /// Remove element + pub fn remove(self: mut Self, value: ref T) -> Bool { + self.map.remove(value).is_some() + } + + /// Check if element exists + pub fn contains(self: ref Self, value: ref T) -> Bool { + self.map.contains_key(value) + } + + 
/// Clear all elements + pub fn clear(self: mut Self) { + self.map.clear(); + } +} + +// ============================================================================ +// LinkedList +// ============================================================================ + +/// Doubly linked list +pub struct LinkedList[T] { + head: Option[own ListNode[T]], + tail: Option[mut ListNode[T]], + len: Int, +} + +struct ListNode[T] { + value: T, + prev: Option[mut ListNode[T]], + next: Option[own ListNode[T]], +} + +impl[T] LinkedList[T] { + /// Create empty list + pub fn new() -> LinkedList[T] { + LinkedList { head: None, tail: None, len: 0 } + } + + /// Get length + pub fn len(self: ref Self) -> Int { + self.len + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.len == 0 + } + + /// Push to front + pub fn push_front(self: mut Self, value: T) { + let node = ListNode { value: value, prev: None, next: self.head }; + self.head = Some(node); + if self.len == 0 { + self.tail = self.head.as_ref().map(\n -> &mut n); + } + self.len += 1; + } + + /// Push to back + pub fn push_back(self: mut Self, value: T) { + let node = ListNode { value: value, prev: self.tail, next: None }; + match self.tail { + Some(ref mut tail) => tail.next = Some(node), + None => self.head = Some(node), + } + self.len += 1; + } + + /// Pop from front + pub fn pop_front(self: mut Self) -> Option[T] { + match self.head { + Some(node) => { + self.head = node.next; + self.len -= 1; + if self.len == 0 { + self.tail = None; + } + Some(node.value) + }, + None => None, + } + } + + /// Pop from back + pub fn pop_back(self: mut Self) -> Option[T] { + // Would need to track tail properly + None + } + + /// Get front element + pub fn front(self: ref Self) -> Option[ref T] { + self.head.as_ref().map(\n -> &n.value) + } + + /// Get back element + pub fn back(self: ref Self) -> Option[ref T] { + self.tail.map(\n -> &n.value) + } +} + +// 
============================================================================ +// BinaryHeap (Priority Queue) +// ============================================================================ + +/// Binary heap (max-heap by default) +pub struct BinaryHeap[T: Ord] { + data: Vec[T], +} + +impl[T: Ord] BinaryHeap[T] { + /// Create empty heap + pub fn new() -> BinaryHeap[T] { + BinaryHeap { data: Vec::new() } + } + + /// Get length + pub fn len(self: ref Self) -> Int { + self.data.len() + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.data.is_empty() + } + + /// Push element + pub fn push(self: mut Self, value: T) { + self.data.push(value); + self.sift_up(self.data.len() - 1); + } + + /// Pop max element + pub fn pop(self: mut Self) -> Option[T] { + if self.data.is_empty() { + None + } else { + let last_idx = self.data.len() - 1; + // Swap first and last + self.data.data[0] = self.data.data[last_idx]; + let max = self.data.pop(); + if !self.data.is_empty() { + self.sift_down(0); + } + max + } + } + + /// Peek max element + pub fn peek(self: ref Self) -> Option[ref T] { + self.data.first() + } + + fn sift_up(self: mut Self, index: Int) { + let mut i = index; + while i > 0 { + let parent = (i - 1) / 2; + if self.data.data[i].gt(&self.data.data[parent]) { + // Swap + let temp = self.data.data[i]; + self.data.data[i] = self.data.data[parent]; + self.data.data[parent] = temp; + i = parent; + } else { + break; + } + } + } + + fn sift_down(self: mut Self, index: Int) { + let mut i = index; + let n = self.data.len(); + while true { + let left = 2 * i + 1; + let right = 2 * i + 2; + let mut largest = i; + + if left < n && self.data.data[left].gt(&self.data.data[largest]) { + largest = left; + } + if right < n && self.data.data[right].gt(&self.data.data[largest]) { + largest = right; + } + + if largest != i { + // Swap + let temp = self.data.data[i]; + self.data.data[i] = self.data.data[largest]; + self.data.data[largest] = temp; + i = largest; + } 
else { + break; + } + } + } +} + +// ============================================================================ +// Deque +// ============================================================================ + +/// Double-ended queue +pub struct Deque[T] { + data: Vec[T], + front: Int, + back: Int, +} + +impl[T] Deque[T] { + /// Create empty deque + pub fn new() -> Deque[T] { + Deque { data: Vec::with_capacity(8), front: 0, back: 0 } + } + + /// Get length + pub fn len(self: ref Self) -> Int { + (self.back - self.front + self.data.capacity()) % self.data.capacity() + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.front == self.back + } + + /// Push to front + pub fn push_front(self: mut Self, value: T) { + self.front = (self.front - 1 + self.data.capacity()) % self.data.capacity(); + self.data.data[self.front] = value; + } + + /// Push to back + pub fn push_back(self: mut Self, value: T) { + self.data.data[self.back] = value; + self.back = (self.back + 1) % self.data.capacity(); + } + + /// Pop from front + pub fn pop_front(self: mut Self) -> Option[T] { + if self.is_empty() { + None + } else { + let value = self.data.data[self.front]; + self.front = (self.front + 1) % self.data.capacity(); + Some(value) + } + } + + /// Pop from back + pub fn pop_back(self: mut Self) -> Option[T] { + if self.is_empty() { + None + } else { + self.back = (self.back - 1 + self.data.capacity()) % self.data.capacity(); + Some(self.data.data[self.back]) + } + } +} diff --git a/library/common/io.afs b/library/common/io.afs new file mode 100644 index 0000000..1b360b3 --- /dev/null +++ b/library/common/io.afs @@ -0,0 +1,576 @@ +// AffineScript Common Library - IO +// Input/Output operations and effects + +module Common.IO; + +use Common.Prelude::*; + +// ============================================================================ +// IO Effect Definition +// ============================================================================ + +/// IO effect for 
side-effectful operations +pub effect IO { + /// Read a line from standard input + fn read_line() -> String; + + /// Write string to standard output + fn print(s: String) -> (); + + /// Write string with newline to standard output + fn println(s: String) -> (); + + /// Read entire file contents + fn read_file(path: String) -> Result[String, IOError]; + + /// Write string to file + fn write_file(path: String, content: String) -> Result[(), IOError]; + + /// Append string to file + fn append_file(path: String, content: String) -> Result[(), IOError]; + + /// Check if file exists + fn file_exists(path: String) -> Bool; + + /// Delete file + fn delete_file(path: String) -> Result[(), IOError]; + + /// Create directory + fn create_dir(path: String) -> Result[(), IOError]; + + /// List directory contents + fn list_dir(path: String) -> Result[Array[String], IOError]; + + /// Get current working directory + fn current_dir() -> Result[String, IOError]; + + /// Set current working directory + fn set_current_dir(path: String) -> Result[(), IOError]; + + /// Get environment variable + fn get_env(name: String) -> Option[String]; + + /// Set environment variable + fn set_env(name: String, value: String) -> (); + + /// Get command line arguments + fn args() -> Array[String]; + + /// Exit program with status code + fn exit(code: Int) -> Never; +} + +// ============================================================================ +// IO Error Types +// ============================================================================ + +/// IO error kinds +pub enum IOError { + /// File or directory not found + NotFound(String), + /// Permission denied + PermissionDenied(String), + /// File already exists + AlreadyExists(String), + /// Invalid input data + InvalidInput(String), + /// Invalid data encoding + InvalidData(String), + /// Connection refused + ConnectionRefused(String), + /// Connection reset + ConnectionReset(String), + /// Connection aborted + ConnectionAborted(String), + /// 
Not connected + NotConnected(String), + /// Address in use + AddrInUse(String), + /// Address not available + AddrNotAvailable(String), + /// Broken pipe + BrokenPipe(String), + /// Would block (non-blocking IO) + WouldBlock(String), + /// Timed out + TimedOut(String), + /// Write zero bytes + WriteZero(String), + /// Interrupted + Interrupted(String), + /// Unexpected end of file + UnexpectedEof(String), + /// Other error + Other(String), +} + +impl IOError { + /// Get error message + pub fn message(self: ref Self) -> String { + match self { + NotFound(msg) => "not found: " + msg, + PermissionDenied(msg) => "permission denied: " + msg, + AlreadyExists(msg) => "already exists: " + msg, + InvalidInput(msg) => "invalid input: " + msg, + InvalidData(msg) => "invalid data: " + msg, + ConnectionRefused(msg) => "connection refused: " + msg, + ConnectionReset(msg) => "connection reset: " + msg, + ConnectionAborted(msg) => "connection aborted: " + msg, + NotConnected(msg) => "not connected: " + msg, + AddrInUse(msg) => "address in use: " + msg, + AddrNotAvailable(msg) => "address not available: " + msg, + BrokenPipe(msg) => "broken pipe: " + msg, + WouldBlock(msg) => "would block: " + msg, + TimedOut(msg) => "timed out: " + msg, + WriteZero(msg) => "write zero: " + msg, + Interrupted(msg) => "interrupted: " + msg, + UnexpectedEof(msg) => "unexpected eof: " + msg, + Other(msg) => msg, + } + } +} + +// ============================================================================ +// Reader/Writer Traits +// ============================================================================ + +/// Read trait for reading bytes +pub trait Read { + /// Read into buffer, return bytes read + fn read(self: mut Self, buf: mut Array[u8]) -> Result[Int, IOError]; + + /// Read exact number of bytes + fn read_exact(self: mut Self, buf: mut Array[u8]) -> Result[(), IOError] { + let mut offset = 0; + while offset < len(buf) { + match self.read(&mut buf[offset..]) { + Ok(0) => return 
Err(IOError::UnexpectedEof("read_exact")), + Ok(n) => offset += n, + Err(e) => return Err(e), + } + } + Ok(()) + } + + /// Read all remaining bytes + fn read_to_end(self: mut Self, buf: mut Vec[u8]) -> Result[Int, IOError] { + let mut total = 0; + let chunk = [0u8; 1024]; + while true { + match self.read(&mut chunk) { + Ok(0) => break, + Ok(n) => { + for i in 0..n { + buf.push(chunk[i]); + } + total += n; + }, + Err(e) => return Err(e), + } + } + Ok(total) + } + + /// Read all remaining bytes as string + fn read_to_string(self: mut Self, buf: mut String) -> Result[Int, IOError] { + let bytes = Vec::new(); + let n = self.read_to_end(&mut bytes)?; + // Convert bytes to string + Ok(n) + } +} + +/// Write trait for writing bytes +pub trait Write { + /// Write buffer, return bytes written + fn write(self: mut Self, buf: ref Array[u8]) -> Result[Int, IOError]; + + /// Flush output buffer + fn flush(self: mut Self) -> Result[(), IOError]; + + /// Write all bytes + fn write_all(self: mut Self, buf: ref Array[u8]) -> Result[(), IOError] { + let mut offset = 0; + while offset < len(buf) { + match self.write(&buf[offset..]) { + Ok(0) => return Err(IOError::WriteZero("write_all")), + Ok(n) => offset += n, + Err(e) => return Err(e), + } + } + Ok(()) + } + + /// Write string + fn write_str(self: mut Self, s: ref String) -> Result[(), IOError] { + // Convert string to bytes and write + Ok(()) + } +} + +/// BufRead trait for buffered reading +pub trait BufRead: Read { + /// Fill internal buffer + fn fill_buf(self: mut Self) -> Result[ref Array[u8], IOError]; + + /// Mark bytes as consumed + fn consume(self: mut Self, amt: Int); + + /// Read until delimiter + fn read_until(self: mut Self, delim: u8, buf: mut Vec[u8]) -> Result[Int, IOError] { + let mut total = 0; + while true { + let available = self.fill_buf()?; + if len(available) == 0 { + break; + } + // Find delimiter + let mut found = false; + for i in 0..len(available) { + buf.push(available[i]); + total += 1; + if 
available[i] == delim { + self.consume(i + 1); + found = true; + break; + } + } + if found { + break; + } + self.consume(len(available)); + } + Ok(total) + } + + /// Read line + fn read_line(self: mut Self, buf: mut String) -> Result[Int, IOError] { + let bytes = Vec::new(); + let n = self.read_until(b'\n', &mut bytes)?; + // Convert to string + Ok(n) + } + + /// Create lines iterator + fn lines(self: Self) -> Lines[Self] { + Lines { reader: self } + } +} + +/// Lines iterator +pub struct Lines[R: BufRead] { + reader: R, +} + +impl[R: BufRead] Iterator for Lines[R] { + type Item = Result[String, IOError]; + + fn next(self: mut Self) -> Option[Result[String, IOError]] { + let mut line = String::new(); + match self.reader.read_line(&mut line) { + Ok(0) => None, + Ok(_) => { + // Remove trailing newline + if line.ends_with("\n") { + line = line[..len(line)-1]; + } + Some(Ok(line)) + }, + Err(e) => Some(Err(e)), + } + } +} + +// ============================================================================ +// File Types +// ============================================================================ + +/// File handle +pub struct File { + // Implementation details hidden + handle: Int, +} + +impl File { + /// Open file for reading + pub fn open(path: String) -> Result[File, IOError] / IO { + // Would use IO effect to open + Ok(File { handle: 0 }) + } + + /// Create new file for writing + pub fn create(path: String) -> Result[File, IOError] / IO { + Ok(File { handle: 0 }) + } + + /// Open with options + pub fn open_with(path: String, options: OpenOptions) -> Result[File, IOError] / IO { + Ok(File { handle: 0 }) + } + + /// Sync all data to disk + pub fn sync_all(self: ref Self) -> Result[(), IOError] / IO { + Ok(()) + } + + /// Sync data (not metadata) to disk + pub fn sync_data(self: ref Self) -> Result[(), IOError] / IO { + Ok(()) + } + + /// Get file metadata + pub fn metadata(self: ref Self) -> Result[Metadata, IOError] / IO { + Ok(Metadata { size: 0, is_dir: false, 
is_file: true }) + } +} + +impl Read for File { + fn read(self: mut Self, buf: mut Array[u8]) -> Result[Int, IOError] { + // Would perform actual read + Ok(0) + } +} + +impl Write for File { + fn write(self: mut Self, buf: ref Array[u8]) -> Result[Int, IOError] { + Ok(len(buf)) + } + + fn flush(self: mut Self) -> Result[(), IOError] { + Ok(()) + } +} + +/// File open options +pub struct OpenOptions { + read: Bool, + write: Bool, + append: Bool, + truncate: Bool, + create: Bool, + create_new: Bool, +} + +impl OpenOptions { + pub fn new() -> OpenOptions { + OpenOptions { + read: false, + write: false, + append: false, + truncate: false, + create: false, + create_new: false, + } + } + + pub fn read(self: mut Self, read: Bool) -> mut Self { + self.read = read; + self + } + + pub fn write(self: mut Self, write: Bool) -> mut Self { + self.write = write; + self + } + + pub fn append(self: mut Self, append: Bool) -> mut Self { + self.append = append; + self + } + + pub fn truncate(self: mut Self, truncate: Bool) -> mut Self { + self.truncate = truncate; + self + } + + pub fn create(self: mut Self, create: Bool) -> mut Self { + self.create = create; + self + } + + pub fn create_new(self: mut Self, create_new: Bool) -> mut Self { + self.create_new = create_new; + self + } + + pub fn open(self: ref Self, path: String) -> Result[File, IOError] / IO { + File::open_with(path, *self) + } +} + +/// File metadata +pub struct Metadata { + size: Int, + is_dir: Bool, + is_file: Bool, +} + +impl Metadata { + pub fn len(self: ref Self) -> Int { self.size } + pub fn is_dir(self: ref Self) -> Bool { self.is_dir } + pub fn is_file(self: ref Self) -> Bool { self.is_file } +} + +// ============================================================================ +// Buffered IO +// ============================================================================ + +/// Buffered reader +pub struct BufReader[R: Read] { + inner: R, + buf: Array[u8], + pos: Int, + cap: Int, +} + +impl[R: Read] 
BufReader[R] { + pub fn new(inner: R) -> BufReader[R] { + BufReader::with_capacity(8192, inner) + } + + pub fn with_capacity(cap: Int, inner: R) -> BufReader[R] { + BufReader { + inner: inner, + buf: [0u8; cap], + pos: 0, + cap: 0, + } + } + + pub fn into_inner(self: Self) -> R { + self.inner + } +} + +impl[R: Read] Read for BufReader[R] { + fn read(self: mut Self, buf: mut Array[u8]) -> Result[Int, IOError] { + // Fill buffer if needed, then copy + if self.pos >= self.cap { + self.cap = self.inner.read(&mut self.buf)?; + self.pos = 0; + } + let available = self.cap - self.pos; + let to_copy = if len(buf) < available { len(buf) } else { available }; + for i in 0..to_copy { + buf[i] = self.buf[self.pos + i]; + } + self.pos += to_copy; + Ok(to_copy) + } +} + +impl[R: Read] BufRead for BufReader[R] { + fn fill_buf(self: mut Self) -> Result[ref Array[u8], IOError] { + if self.pos >= self.cap { + self.cap = self.inner.read(&mut self.buf)?; + self.pos = 0; + } + Ok(&self.buf[self.pos..self.cap]) + } + + fn consume(self: mut Self, amt: Int) { + self.pos += amt; + } +} + +/// Buffered writer +pub struct BufWriter[W: Write] { + inner: W, + buf: Vec[u8], +} + +impl[W: Write] BufWriter[W] { + pub fn new(inner: W) -> BufWriter[W] { + BufWriter::with_capacity(8192, inner) + } + + pub fn with_capacity(cap: Int, inner: W) -> BufWriter[W] { + BufWriter { + inner: inner, + buf: Vec::with_capacity(cap), + } + } + + pub fn into_inner(self: mut Self) -> Result[W, IOError] { + self.flush()?; + Ok(self.inner) + } +} + +impl[W: Write] Write for BufWriter[W] { + fn write(self: mut Self, buf: ref Array[u8]) -> Result[Int, IOError] { + if self.buf.len() + len(buf) > self.buf.capacity() { + self.flush()?; + } + if len(buf) >= self.buf.capacity() { + self.inner.write(buf) + } else { + for b in buf { + self.buf.push(*b); + } + Ok(len(buf)) + } + } + + fn flush(self: mut Self) -> Result[(), IOError] { + if !self.buf.is_empty() { + self.inner.write_all(&self.buf.data)?; + self.buf.clear(); + } + 
self.inner.flush() + } +} + +// ============================================================================ +// Standard Streams +// ============================================================================ + +/// Standard input +pub fn stdin() -> Stdin { + Stdin { } +} + +pub struct Stdin { } + +impl Read for Stdin { + fn read(self: mut Self, buf: mut Array[u8]) -> Result[Int, IOError] { + // Would read from stdin + Ok(0) + } +} + +/// Standard output +pub fn stdout() -> Stdout { + Stdout { } +} + +pub struct Stdout { } + +impl Write for Stdout { + fn write(self: mut Self, buf: ref Array[u8]) -> Result[Int, IOError] { + // Would write to stdout + Ok(len(buf)) + } + + fn flush(self: mut Self) -> Result[(), IOError] { + Ok(()) + } +} + +/// Standard error +pub fn stderr() -> Stderr { + Stderr { } +} + +pub struct Stderr { } + +impl Write for Stderr { + fn write(self: mut Self, buf: ref Array[u8]) -> Result[Int, IOError] { + // Would write to stderr + Ok(len(buf)) + } + + fn flush(self: mut Self) -> Result[(), IOError] { + Ok(()) + } +} diff --git a/library/common/math.afs b/library/common/math.afs new file mode 100644 index 0000000..ebe8dae --- /dev/null +++ b/library/common/math.afs @@ -0,0 +1,621 @@ +// AffineScript Common Library - Math +// Mathematical functions and constants + +module Common.Math; + +use Common.Prelude::*; + +// ============================================================================ +// Constants +// ============================================================================ + +/// Pi (π) +pub const PI: Float = 3.141592653589793; + +/// Tau (2π) +pub const TAU: Float = 6.283185307179586; + +/// Euler's number (e) +pub const E: Float = 2.718281828459045; + +/// Golden ratio (φ) +pub const PHI: Float = 1.618033988749895; + +/// Square root of 2 +pub const SQRT2: Float = 1.4142135623730951; + +/// Square root of 3 +pub const SQRT3: Float = 1.7320508075688772; + +/// Natural log of 2 +pub const LN2: Float = 0.6931471805599453; + +/// 
/// Natural log of 10
pub const LN10: Float = 2.302585092994046;

/// Positive infinity
pub const INFINITY: Float = 1.0 / 0.0;

/// Negative infinity
pub const NEG_INFINITY: Float = -1.0 / 0.0;

/// Not a number
pub const NAN: Float = 0.0 / 0.0;

// ============================================================================
// Basic Functions
// ============================================================================

/// Absolute value (integer).
pub fn abs(x: Int) -> Int {
    if x < 0 { -x } else { x }
}

/// Absolute value (float).
///
/// Both zeros map to positive zero; NaN is returned unchanged.
pub fn fabs(x: Float) -> Float {
    // `0.0 - x` (rather than `-x`) yields +0.0 for +0.0 and -0.0 alike.
    // NaN fails the comparison and falls through untouched.
    if x <= 0.0 { 0.0 - x } else { x }
}

/// Sign of integer: -1, 0, or 1.
pub fn signum(x: Int) -> Int {
    if x < 0 { -1 }
    else if x > 0 { 1 }
    else { 0 }
}

/// Sign of float: -1.0, 0.0, or 1.0.
///
/// NaN has no sign, so it is propagated instead of being reported as 0.0.
pub fn fsignum(x: Float) -> Float {
    if x < 0.0 { -1.0 }
    else if x > 0.0 { 1.0 }
    else if x == 0.0 { 0.0 }
    else { x } // only NaN fails all three comparisons
}

/// Maximum of two integers.
pub fn max(a: Int, b: Int) -> Int {
    if a >= b { a } else { b }
}

/// Minimum of two integers.
pub fn min(a: Int, b: Int) -> Int {
    if a <= b { a } else { b }
}

/// Maximum of two floats.
///
/// If exactly one argument is NaN the other argument is returned, regardless
/// of argument order; the result is NaN only when both are NaN.
pub fn fmax(a: Float, b: Float) -> Float {
    if a != a { b }        // a is NaN
    else if b != b { a }   // b is NaN
    else if a >= b { a }
    else { b }
}

/// Minimum of two floats.
///
/// If exactly one argument is NaN the other argument is returned, regardless
/// of argument order; the result is NaN only when both are NaN.
pub fn fmin(a: Float, b: Float) -> Float {
    if a != a { b }        // a is NaN
    else if b != b { a }   // b is NaN
    else if a <= b { a }
    else { b }
}

/// Clamp value to the range [lo, hi].
///
/// The lower bound is tested first, so when `lo > hi` any `x < lo`
/// yields `lo`.
pub fn clamp(x: Int, lo: Int, hi: Int) -> Int {
    if x < lo { lo }
    else if x > hi { hi }
    else { x }
}

/// Clamp float to the range [lo, hi].
///
/// A NaN `x` fails both comparisons and is returned unchanged.
pub fn fclamp(x: Float, lo: Float, hi: Float) -> Float {
    if x < lo { lo }
    else if x > hi { hi }
    else { x }
}

// ============================================================================
// Rounding Functions
// ============================================================================

/// Floor (round toward negative infinity).
///
/// Placeholder: this body returns `x` unchanged; the rounding is expected to
/// come from a native implementation.
pub fn floor(x: Float) -> Float {
    // Native implementation
    x
}

/// Ceiling (round toward positive infinity).
///
/// Placeholder: this body returns `x` unchanged; the rounding is expected to
/// come from a native implementation.
pub fn ceil(x: Float) -> Float {
    // Native implementation
    x
}

/// Round to
/// nearest integer (ties to even)
///
/// Placeholder: this body returns `x` unchanged; the rounding is expected to
/// come from a native implementation.
pub fn round(x: Float) -> Float {
    // Native implementation
    x
}

/// Truncate toward zero.
///
/// Placeholder: this body returns `x` unchanged.
pub fn trunc(x: Float) -> Float {
    // Native implementation
    x
}

/// Fractional part: `x - trunc(x)`.
pub fn fract(x: Float) -> Float {
    x - trunc(x)
}

// ============================================================================
// Power and Exponential Functions
// ============================================================================

/// Square root.
///
/// Placeholder: this body returns `x` unchanged.
pub fn sqrt(x: Float) -> Float {
    // Native implementation
    x
}

/// Cube root.
///
/// Placeholder: this body returns `x` unchanged.
pub fn cbrt(x: Float) -> Float {
    // Native implementation
    x
}

/// Power function.
///
/// Placeholder: this body returns `base` unchanged and ignores `exp`.
pub fn pow(base: Float, exp: Float) -> Float {
    // Native implementation
    base
}

/// Integer power of a float, by repeated squaring.
///
/// A negative exponent yields the reciprocal of the positive power.
pub fn powi(base: Float, exp: Int) -> Float {
    if exp == 0 { 1.0 }
    else if exp < 0 { 1.0 / powi(base, -exp) }
    else if exp % 2 == 0 {
        let half = powi(base, exp / 2);
        half * half
    } else {
        base * powi(base, exp - 1)
    }
}

/// Integer power (integer base), by repeated squaring.
///
/// For a negative exponent the result is `1 / base^|exp|` truncated toward
/// zero: 1 for base 1, +/-1 for base -1 (by parity of `exp`), and 0 for every
/// other base. Base 0 with a negative exponent is mathematically undefined;
/// it also returns 0 here.
pub fn ipow(base: Int, exp: Int) -> Int {
    if exp == 0 { 1 }
    else if exp < 0 {
        // Only |base| == 1 survives truncation of the reciprocal.
        if base == 1 { 1 }
        else if base == -1 {
            if exp % 2 == 0 { 1 } else { -1 }
        }
        else { 0 }
    }
    else if exp % 2 == 0 {
        let half = ipow(base, exp / 2);
        half * half
    } else {
        base * ipow(base, exp - 1)
    }
}

/// Exponential function (e^x).
///
/// Placeholder: this body returns `x` unchanged.
pub fn exp(x: Float) -> Float {
    // Native implementation
    x
}

/// Exponential minus one (e^x - 1), accurate for small x.
///
/// Placeholder: this body returns `x` unchanged.
pub fn expm1(x: Float) -> Float {
    // Native implementation
    x
}

/// 2^x, computed as `pow(2.0, x)`.
pub fn exp2(x: Float) -> Float {
    pow(2.0, x)
}

// ============================================================================
// Logarithm Functions
// ============================================================================

/// Natural logarithm.
///
/// Placeholder: this body returns `x` unchanged.
pub fn ln(x: Float) -> Float {
    // Native implementation
    x
}

/// Natural log of (1 + x), accurate for small x.
///
/// Placeholder: this body returns `x` unchanged.
pub fn ln1p(x: Float) -> Float {
    // Native implementation
    x
}

/// Base-2 logarithm
pub fn log2(x: Float) -> Float {
    ln(x) / LN2
}

/// Base-10 logarithm, computed as `ln(x) / LN10`.
pub fn log10(x: Float) -> Float {
    ln(x) / LN10
}

/// Arbitrary base logarithm, computed as `ln(x) / ln(base)`.
pub fn log(x: Float, base: Float) -> Float {
    ln(x) / ln(base)
}

// ============================================================================
// Trigonometric Functions
// ============================================================================

/// Sine.
///
/// Placeholder: this body returns `x` unchanged.
pub fn sin(x: Float) -> Float {
    // Native implementation
    x
}

/// Cosine.
///
/// Placeholder: this body returns `x` unchanged.
pub fn cos(x: Float) -> Float {
    // Native implementation
    x
}

/// Tangent, computed as `sin(x) / cos(x)`.
pub fn tan(x: Float) -> Float {
    sin(x) / cos(x)
}

/// Arcsine.
///
/// Placeholder: this body returns `x` unchanged.
pub fn asin(x: Float) -> Float {
    // Native implementation
    x
}

/// Arccosine.
///
/// Placeholder: this body returns `x` unchanged.
pub fn acos(x: Float) -> Float {
    // Native implementation
    x
}

/// Arctangent.
///
/// Placeholder: this body returns `x` unchanged.
pub fn atan(x: Float) -> Float {
    // Native implementation
    x
}

/// Two-argument arctangent.
///
/// Placeholder: this body returns `y` unchanged and ignores `x`.
pub fn atan2(y: Float, x: Float) -> Float {
    // Native implementation
    y
}

/// Convert degrees to radians.
pub fn to_radians(degrees: Float) -> Float {
    degrees * PI / 180.0
}

/// Convert radians to degrees.
pub fn to_degrees(radians: Float) -> Float {
    radians * 180.0 / PI
}

// ============================================================================
// Hyperbolic Functions
// ============================================================================

/// Hyperbolic sine: `(e^x - e^-x) / 2`.
pub fn sinh(x: Float) -> Float {
    (exp(x) - exp(-x)) / 2.0
}

/// Hyperbolic cosine: `(e^x + e^-x) / 2`.
pub fn cosh(x: Float) -> Float {
    (exp(x) + exp(-x)) / 2.0
}

/// Hyperbolic tangent.
///
/// Evaluated as `(1 - e^(-2|x|)) / (1 + e^(-2|x|))` with the sign of `x`
/// restored afterwards. The exponent is never positive, so the result
/// saturates at +/-1 for large |x| instead of degenerating into inf / inf
/// (NaN), which is what `sinh(x) / cosh(x)` produces once `exp` overflows.
pub fn tanh(x: Float) -> Float {
    // t = e^(-2|x|) - 1, so 1 - e^(-2|x|) = -t and 1 + e^(-2|x|) = t + 2.
    let t = expm1(-2.0 * fabs(x));
    let magnitude = -t / (t + 2.0);
    if x < 0.0 { -magnitude } else { magnitude }
}

/// Inverse hyperbolic sine.
///
/// asinh is odd, so the formula is evaluated on |x| and the sign restored.
/// Evaluating `x + sqrt(x * x + 1.0)` directly for large negative `x`
/// subtracts two nearly equal numbers and can collapse to `ln(0.0)`.
pub fn asinh(x: Float) -> Float {
    let a = fabs(x);
    let r = ln(a + sqrt(a * a + 1.0));
    if x < 0.0 { -r } else { r }
}

/// Inverse hyperbolic cosine: `ln(x + sqrt(x^2 - 1))`.
pub fn acosh(x: Float) -> Float {
    ln(x + sqrt(x * x - 1.0))
}

/// Inverse hyperbolic tangent: `0.5 * ln((1 + x) / (1 - x))`.
pub fn atanh(x: Float) -> Float {
    0.5 * ln((1.0 + x) / (1.0 - x))
}

//
// ============================================================================
// Integer Arithmetic
// ============================================================================

/// Greatest common divisor (Euclid's algorithm on absolute values).
///
/// `gcd(a, 0)` is `abs(a)`, so `gcd(0, 0)` is 0.
pub fn gcd(a: Int, b: Int) -> Int {
    let a = abs(a);
    let b = abs(b);
    if b == 0 { a }
    else { gcd(b, a % b) }
}

/// Least common multiple; 0 if either argument is 0.
pub fn lcm(a: Int, b: Int) -> Int {
    if a == 0 || b == 0 { 0 }
    // Divide before multiplying to keep the intermediate value small.
    else { abs(a) / gcd(a, b) * abs(b) }
}

/// Integer division with floor rounding.
///
/// Written on the assumption that `/` truncates toward zero and `%` takes
/// the sign of the dividend -- TODO confirm against the language spec.
pub fn div_floor(a: Int, b: Int) -> Int {
    let q = a / b;
    let r = a % b;
    // An inexact quotient with operands of opposite sign was rounded up.
    if r != 0 && ((a < 0) != (b < 0)) {
        q - 1
    } else {
        q
    }
}

/// Integer division with ceiling rounding.
///
/// Same assumption about `/` and `%` as `div_floor`.
pub fn div_ceil(a: Int, b: Int) -> Int {
    let q = a / b;
    let r = a % b;
    // An inexact quotient with operands of the same sign was rounded down.
    if r != 0 && ((a > 0) == (b > 0)) {
        q + 1
    } else {
        q
    }
}

/// Euclidean modulo (always non-negative).
pub fn mod_euc(a: Int, b: Int) -> Int {
    let r = a % b;
    if r < 0 {
        // Shift a negative remainder up by |b|.
        if b < 0 { r - b } else { r + b }
    } else {
        r
    }
}

/// Factorial. Returns 1 for every `n <= 1`, including negative `n`.
pub fn factorial(n: Int) -> Int {
    if n <= 1 { 1 }
    else { n * factorial(n - 1) }
}

/// Binomial coefficient (n choose k); 0 when `k < 0` or `k > n`.
pub fn binomial(n: Int, k: Int) -> Int {
    if k < 0 || k > n { 0 }
    else if k == 0 || k == n { 1 }
    else {
        // C(n, k) == C(n, n - k): iterate over the smaller of the two.
        let k = if k > n - k { n - k } else { k };
        let mut result = 1;
        for i in 0..k {
            // After this step `result` is C(n, i + 1), so the division is exact.
            result = result * (n - i) / (i + 1);
        }
        result
    }
}

/// Check if number is prime (trial division by odd numbers).
///
/// The loop bound is the integer test `i <= n / i` (equivalent to
/// `i * i <= n` without the overflow), so the result does not depend on
/// floating-point `sqrt`.
pub fn is_prime(n: Int) -> Bool {
    if n < 2 { false }
    else if n == 2 { true }
    else if n % 2 == 0 { false }
    else {
        let mut i = 3;
        while i <= n / i {
            if n % i == 0 { return false; }
            i += 2;
        }
        true
    }
}

/// Next prime >= n (2 for every `n < 2`).
pub fn next_prime(n: Int) -> Int {
    let mut candidate = if n < 2 { 2 } else { n };
    while !is_prime(candidate) {
        candidate += 1;
    }
    candidate
}

// ============================================================================
// Float Utilities
//
// ============================================================================

/// Check if float is NaN (NaN is the only value not equal to itself).
pub fn is_nan(x: Float) -> Bool {
    x != x
}

/// Check if float is positive or negative infinity.
pub fn is_infinite(x: Float) -> Bool {
    x == INFINITY || x == NEG_INFINITY
}

/// Check if float is finite (neither NaN nor infinite).
pub fn is_finite(x: Float) -> Bool {
    !is_nan(x) && !is_infinite(x)
}

/// Check if float is normal (not zero, subnormal, infinite, or NaN).
///
/// The bounds are the smallest positive normal and the largest finite
/// values of a 64-bit IEEE 754 float -- NOTE(review): assumes `Float` is
/// binary64; confirm against the language spec.
pub fn is_normal(x: Float) -> Bool {
    let mag = if x < 0.0 { 0.0 - x } else { x };
    // NaN fails both comparisons; zero and subnormals fail the first;
    // infinities fail the second.
    mag >= 2.2250738585072014e-308 && mag <= 1.7976931348623157e308
}

/// Fused multiply-add: a * b + c
///
/// Currently evaluated as two separately rounded operations.
pub fn fma(a: Float, b: Float, c: Float) -> Float {
    // Native implementation would be more accurate
    a * b + c
}

/// Copy sign of y to x.
///
/// Negative zero counts as negative, and a zero magnitude is normalised to
/// +0.0 before the sign is applied.
pub fn copysign(x: Float, y: Float) -> Float {
    // `-0.0 < 0.0` is false, so detect negative zero through its reciprocal.
    let negative = y < 0.0 || (y == 0.0 && 1.0 / y < 0.0);
    // `x == 0.0` holds for both zeros; start from +0.0 in that case.
    let mag = if x == 0.0 { 0.0 } else { fabs(x) };
    if negative { -mag } else { mag }
}

/// Hypotenuse: sqrt(x^2 + y^2)
///
/// The larger magnitude is factored out before squaring so that inputs
/// whose squares would overflow still give a finite result when the true
/// hypotenuse is representable. An infinite argument yields infinity.
pub fn hypot(x: Float, y: Float) -> Float {
    let ax = fabs(x);
    let ay = fabs(y);
    let big = if ax > ay { ax } else { ay };
    let small = if ax > ay { ay } else { ax };
    if ax == INFINITY || ay == INFINITY { INFINITY }
    else if big == 0.0 && small == 0.0 { 0.0 }
    else {
        // big * sqrt(1 + (small / big)^2); the ratio never exceeds 1.
        let ratio = small / big;
        big * sqrt(1.0 + ratio * ratio)
    }
}

// ============================================================================
// Interpolation
// ============================================================================

/// Linear interpolation: `a` at `t == 0`, `b` at `t == 1`.
pub fn lerp(a: Float, b: Float, t: Float) -> Float {
    a + (b - a) * t
}

/// Inverse linear interpolation: the `t` for which `lerp(a, b, t) == v`.
///
/// Divides by `b - a`, so `a == b` produces a non-finite result.
pub fn inv_lerp(a: Float, b: Float, v: Float) -> Float {
    (v - a) / (b - a)
}

/// Remap value from one range to another.
pub fn remap(v: Float, in_min: Float, in_max: Float, out_min: Float, out_max: Float) -> Float {
    let t = inv_lerp(in_min, in_max, v);
    lerp(out_min, out_max, t)
}

/// Smooth step interpolation (Hermite): `3t^2 - 2t^3` on the clamped `t`.
///
/// Divides by `edge1 - edge0`; the edges are expected to differ.
pub fn smoothstep(edge0: Float, edge1: Float, x: Float) -> Float {
    let t = fclamp((x - edge0) / (edge1 - edge0), 0.0, 1.0);
    t * t * (3.0 - 2.0 * t)
}

/// Smoother step (Ken Perlin's improvement): `6t^5 - 15t^4 + 10t^3`.
///
/// Divides by `edge1 - edge0`; the edges are expected to differ.
pub fn smootherstep(edge0: Float, edge1: Float, x: Float) -> Float {
    let t = fclamp((x - edge0) / (edge1 - edge0), 0.0, 1.0);
    t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
}

//
============================================================================ +// Statistics +// ============================================================================ + +/// Sum of array +pub fn sum(arr: Array[Int]) -> Int { + let mut total = 0; + for x in arr { + total += x; + } + total +} + +/// Sum of float array +pub fn fsum(arr: Array[Float]) -> Float { + let mut total = 0.0; + for x in arr { + total += x; + } + total +} + +/// Mean of float array +pub fn mean(arr: Array[Float]) -> Float { + if len(arr) == 0 { 0.0 } + else { fsum(arr) / (len(arr) as Float) } +} + +/// Variance of float array +pub fn variance(arr: Array[Float]) -> Float { + if len(arr) == 0 { 0.0 } + else { + let m = mean(arr); + let mut sum_sq = 0.0; + for x in arr { + let diff = x - m; + sum_sq += diff * diff; + } + sum_sq / (len(arr) as Float) + } +} + +/// Standard deviation +pub fn std_dev(arr: Array[Float]) -> Float { + sqrt(variance(arr)) +} + +/// Maximum of array +pub fn max_of(arr: Array[Int]) -> Option[Int] { + if len(arr) == 0 { None } + else { + let mut m = arr[0]; + for x in arr { + if x > m { m = x; } + } + Some(m) + } +} + +/// Minimum of array +pub fn min_of(arr: Array[Int]) -> Option[Int] { + if len(arr) == 0 { None } + else { + let mut m = arr[0]; + for x in arr { + if x < m { m = x; } + } + Some(m) + } +} + +// ============================================================================ +// Random Numbers (Simple LCG) +// ============================================================================ + +/// Random number generator state +pub struct Rng { + state: u64, +} + +impl Rng { + /// Create new RNG with seed + pub fn new(seed: u64) -> Rng { + Rng { state: seed } + } + + /// Generate next u64 + pub fn next_u64(self: mut Self) -> u64 { + // LCG parameters from Numerical Recipes + self.state = self.state * 6364136223846793005 + 1442695040888963407; + self.state + } + + /// Generate random int in range [0, n) + pub fn next_int(self: mut Self, n: Int) -> Int { + 
(self.next_u64() % (n as u64)) as Int + } + + /// Generate random float in [0, 1) + pub fn next_float(self: mut Self) -> Float { + (self.next_u64() as Float) / (18446744073709551616.0) + } + + /// Generate random float in [lo, hi) + pub fn next_float_range(self: mut Self, lo: Float, hi: Float) -> Float { + lo + self.next_float() * (hi - lo) + } + + /// Generate random bool + pub fn next_bool(self: mut Self) -> Bool { + self.next_u64() % 2 == 0 + } + + /// Shuffle array in place + pub fn shuffle[T](self: mut Self, arr: mut Array[T]) { + let n = len(arr); + for i in (1..n).rev() { + let j = self.next_int(i + 1); + let tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + } + + /// Choose random element + pub fn choose[T](self: mut Self, arr: ref Array[T]) -> Option[ref T] { + if len(arr) == 0 { None } + else { Some(&arr[self.next_int(len(arr))]) } + } +} + diff --git a/library/common/mod.afs b/library/common/mod.afs new file mode 100644 index 0000000..a2817e2 --- /dev/null +++ b/library/common/mod.afs @@ -0,0 +1,19 @@ +// AffineScript Common Library +// Re-exports all common modules + +module Common; + +pub use Common.Prelude::*; +pub use Common.Collections::*; +pub use Common.IO::*; +pub use Common.Async::*; +pub use Common.String::*; +pub use Common.Math::*; +pub use Common.Time::*; +pub use Common.Sync::*; + +/// Library version +pub const VERSION: String = "0.1.0"; + +/// Library name +pub const NAME: String = "common"; diff --git a/library/common/prelude.afs b/library/common/prelude.afs new file mode 100644 index 0000000..630fcce --- /dev/null +++ b/library/common/prelude.afs @@ -0,0 +1,378 @@ +// AffineScript Common Library - Prelude +// This module is automatically imported into every AffineScript program + +module Common.Prelude; + +// ============================================================================ +// Core Types +// ============================================================================ + +/// Unit type - the type with only one value +pub 
type Unit = (); + +/// Never type - the bottom type with no values (for diverging functions) +pub type Never = !; + +/// Option type for nullable values +pub enum Option[T] { + /// No value present + None, + /// A value is present + Some(T), +} + +/// Result type for fallible operations +pub enum Result[T, E] { + /// Operation succeeded with value + Ok(T), + /// Operation failed with error + Err(E), +} + +/// Ordering result for comparisons +pub enum Ordering { + Less, + Equal, + Greater, +} + +/// Either type for sum types +pub enum Either[L, R] { + Left(L), + Right(R), +} + +// ============================================================================ +// Core Traits +// ============================================================================ + +/// Equality comparison +pub trait Eq { + fn eq(self: ref Self, other: ref Self) -> Bool; + + fn ne(self: ref Self, other: ref Self) -> Bool { + !self.eq(other) + } +} + +/// Ordering comparison +pub trait Ord: Eq { + fn cmp(self: ref Self, other: ref Self) -> Ordering; + + fn lt(self: ref Self, other: ref Self) -> Bool { + match self.cmp(other) { + Ordering::Less => true, + _ => false, + } + } + + fn le(self: ref Self, other: ref Self) -> Bool { + match self.cmp(other) { + Ordering::Greater => false, + _ => true, + } + } + + fn gt(self: ref Self, other: ref Self) -> Bool { + match self.cmp(other) { + Ordering::Greater => true, + _ => false, + } + } + + fn ge(self: ref Self, other: ref Self) -> Bool { + match self.cmp(other) { + Ordering::Less => false, + _ => true, + } + } + + fn max(self: Self, other: Self) -> Self { + if self.ge(&other) { self } else { other } + } + + fn min(self: Self, other: Self) -> Self { + if self.le(&other) { self } else { other } + } +} + +/// Clone trait for types that can be duplicated +pub trait Clone { + fn clone(self: ref Self) -> Self; +} + +/// Default trait for types with a default value +pub trait Default { + fn default() -> Self; +} + +/// Display trait for human-readable 
formatting +pub trait Display { + fn fmt(self: ref Self) -> String; +} + +/// Debug trait for debug formatting +pub trait Debug { + fn debug_fmt(self: ref Self) -> String; +} + +/// Hash trait for hashable types +pub trait Hash { + fn hash(self: ref Self) -> Int; +} + +/// From trait for conversions +pub trait From[T] { + fn from(value: T) -> Self; +} + +/// Into trait (reciprocal of From) +pub trait Into[T] { + fn into(self: Self) -> T; +} + +/// Iterator trait +pub trait Iterator { + type Item; + + fn next(self: mut Self) -> Option[Self::Item]; + + fn map[U](self: Self, f: fn(Self::Item) -> U) -> Map[Self, fn(Self::Item) -> U] { + Map { iter: self, f: f } + } + + fn filter(self: Self, pred: fn(ref Self::Item) -> Bool) -> Filter[Self, fn(ref Self::Item) -> Bool] { + Filter { iter: self, pred: pred } + } + + fn fold[A](self: Self, init: A, f: fn(A, Self::Item) -> A) -> A { + let mut acc = init; + while let Some(item) = self.next() { + acc = f(acc, item); + } + acc + } + + fn collect[C: FromIterator[Self::Item]](self: Self) -> C { + C::from_iter(self) + } +} + +/// FromIterator trait for collecting iterators +pub trait FromIterator[T] { + fn from_iter[I: Iterator[Item = T]](iter: I) -> Self; +} + +// ============================================================================ +// Option Methods +// ============================================================================ + +impl[T] Option[T] { + /// Check if option contains a value + pub fn is_some(self: ref Self) -> Bool { + match self { + Some(_) => true, + None => false, + } + } + + /// Check if option is empty + pub fn is_none(self: ref Self) -> Bool { + match self { + None => true, + Some(_) => false, + } + } + + /// Unwrap the value or panic + pub fn unwrap(self: Self) -> T { + match self { + Some(x) => x, + None => panic("called unwrap on None"), + } + } + + /// Unwrap the value or return default + pub fn unwrap_or(self: Self, default: T) -> T { + match self { + Some(x) => x, + None => default, + } + } + + 
/// Unwrap the value or compute default + pub fn unwrap_or_else(self: Self, f: fn() -> T) -> T { + match self { + Some(x) => x, + None => f(), + } + } + + /// Map over the contained value + pub fn map[U](self: Self, f: fn(T) -> U) -> Option[U] { + match self { + Some(x) => Some(f(x)), + None => None, + } + } + + /// Flat map / and_then + pub fn and_then[U](self: Self, f: fn(T) -> Option[U]) -> Option[U] { + match self { + Some(x) => f(x), + None => None, + } + } + + /// Filter the option + pub fn filter(self: Self, pred: fn(ref T) -> Bool) -> Option[T] { + match self { + Some(ref x) if pred(x) => self, + _ => None, + } + } + + /// Get reference to contained value + pub fn as_ref(self: ref Self) -> Option[ref T] { + match self { + Some(ref x) => Some(x), + None => None, + } + } +} + +// ============================================================================ +// Result Methods +// ============================================================================ + +impl[T, E] Result[T, E] { + /// Check if result is Ok + pub fn is_ok(self: ref Self) -> Bool { + match self { + Ok(_) => true, + Err(_) => false, + } + } + + /// Check if result is Err + pub fn is_err(self: ref Self) -> Bool { + match self { + Err(_) => true, + Ok(_) => false, + } + } + + /// Unwrap Ok value or panic + pub fn unwrap(self: Self) -> T { + match self { + Ok(x) => x, + Err(_) => panic("called unwrap on Err"), + } + } + + /// Unwrap Err value or panic + pub fn unwrap_err(self: Self) -> E { + match self { + Err(e) => e, + Ok(_) => panic("called unwrap_err on Ok"), + } + } + + /// Unwrap Ok or return default + pub fn unwrap_or(self: Self, default: T) -> T { + match self { + Ok(x) => x, + Err(_) => default, + } + } + + /// Map over Ok value + pub fn map[U](self: Self, f: fn(T) -> U) -> Result[U, E] { + match self { + Ok(x) => Ok(f(x)), + Err(e) => Err(e), + } + } + + /// Map over Err value + pub fn map_err[F](self: Self, f: fn(E) -> F) -> Result[T, F] { + match self { + Ok(x) => Ok(x), + Err(e) => 
Err(f(e)), + } + } + + /// Flat map / and_then + pub fn and_then[U](self: Self, f: fn(T) -> Result[U, E]) -> Result[U, E] { + match self { + Ok(x) => f(x), + Err(e) => Err(e), + } + } + + /// Convert to Option (discards error) + pub fn ok(self: Self) -> Option[T] { + match self { + Ok(x) => Some(x), + Err(_) => None, + } + } + + /// Convert to Option (discards success) + pub fn err(self: Self) -> Option[E] { + match self { + Ok(_) => None, + Err(e) => Some(e), + } + } +} + +// ============================================================================ +// Utility Functions +// ============================================================================ + +/// Identity function +pub total fn identity[T](x: T) -> T { + x +} + +/// Constant function +pub total fn const_[T, U](x: T) -> fn(U) -> T { + \_ -> x +} + +/// Function composition +pub total fn compose[A, B, C](f: fn(B) -> C, g: fn(A) -> B) -> fn(A) -> C { + \x -> f(g(x)) +} + +/// Flip function arguments +pub total fn flip[A, B, C](f: fn(A, B) -> C) -> fn(B, A) -> C { + \b, a -> f(a, b) +} + +/// Panic with message +pub fn panic(msg: String) -> Never { + // Implementation provided by runtime + unreachable() +} + +/// Assert condition +pub fn assert(cond: Bool) { + if !cond { + panic("assertion failed"); + } +} + +/// Assert with message +pub fn assert_eq[T: Eq + Debug](left: T, right: T) { + if !left.eq(&right) { + panic("assertion failed: " + left.debug_fmt() + " != " + right.debug_fmt()); + } +} + +/// Debug print (for development) +pub fn dbg[T: Debug](value: T) -> T { + println(value.debug_fmt()); + value +} diff --git a/library/common/string.afs b/library/common/string.afs new file mode 100644 index 0000000..97b39a4 --- /dev/null +++ b/library/common/string.afs @@ -0,0 +1,550 @@ +// AffineScript Common Library - String +// String manipulation utilities + +module Common.String; + +use Common.Prelude::*; + +// ============================================================================ +// String 
Extensions +// ============================================================================ + +impl String { + /// Get length in bytes + pub fn len(self: ref Self) -> Int { + // Native implementation + 0 + } + + /// Check if string is empty + pub fn is_empty(self: ref Self) -> Bool { + self.len() == 0 + } + + /// Get character at index + pub fn char_at(self: ref Self, index: Int) -> Option[Char] { + if index >= 0 && index < self.len() { + Some(self.chars().nth(index).unwrap()) + } else { + None + } + } + + /// Get substring + pub fn substring(self: ref Self, start: Int, end: Int) -> String { + // Native implementation + "" + } + + /// Split by delimiter + pub fn split(self: ref Self, delimiter: String) -> Array[String] { + // Native implementation + [] + } + + /// Split by whitespace + pub fn split_whitespace(self: ref Self) -> Array[String] { + self.split(" ") + } + + /// Split into lines + pub fn lines(self: ref Self) -> Array[String] { + self.split("\n") + } + + /// Join array of strings with separator + pub fn join(parts: Array[String], separator: String) -> String { + if len(parts) == 0 { + "" + } else { + let mut result = parts[0]; + for i in 1..len(parts) { + result = result + separator + parts[i]; + } + result + } + } + + /// Check if string contains substring + pub fn contains(self: ref Self, needle: String) -> Bool { + self.find(needle).is_some() + } + + /// Find first occurrence of substring + pub fn find(self: ref Self, needle: String) -> Option[Int] { + // Native implementation + None + } + + /// Find last occurrence of substring + pub fn rfind(self: ref Self, needle: String) -> Option[Int] { + // Native implementation + None + } + + /// Check if string starts with prefix + pub fn starts_with(self: ref Self, prefix: String) -> Bool { + self.len() >= prefix.len() && + self.substring(0, prefix.len()) == prefix + } + + /// Check if string ends with suffix + pub fn ends_with(self: ref Self, suffix: String) -> Bool { + self.len() >= suffix.len() && + 
self.substring(self.len() - suffix.len(), self.len()) == suffix + } + + /// Strip prefix if present + pub fn strip_prefix(self: ref Self, prefix: String) -> Option[String] { + if self.starts_with(prefix) { + Some(self.substring(prefix.len(), self.len())) + } else { + None + } + } + + /// Strip suffix if present + pub fn strip_suffix(self: ref Self, suffix: String) -> Option[String] { + if self.ends_with(suffix) { + Some(self.substring(0, self.len() - suffix.len())) + } else { + None + } + } + + /// Trim whitespace from both ends + pub fn trim(self: ref Self) -> String { + self.trim_start().trim_end() + } + + /// Trim whitespace from start + pub fn trim_start(self: ref Self) -> String { + // Native implementation + *self + } + + /// Trim whitespace from end + pub fn trim_end(self: ref Self) -> String { + // Native implementation + *self + } + + /// Convert to uppercase + pub fn to_uppercase(self: ref Self) -> String { + // Native implementation + *self + } + + /// Convert to lowercase + pub fn to_lowercase(self: ref Self) -> String { + // Native implementation + *self + } + + /// Capitalize first character + pub fn capitalize(self: ref Self) -> String { + if self.is_empty() { + "" + } else { + self.char_at(0).unwrap().to_uppercase().to_string() + + self.substring(1, self.len()).to_lowercase() + } + } + + /// Replace all occurrences + pub fn replace(self: ref Self, from: String, to: String) -> String { + String::join(self.split(from), to) + } + + /// Replace first occurrence + pub fn replace_first(self: ref Self, from: String, to: String) -> String { + match self.find(from) { + Some(idx) => { + self.substring(0, idx) + to + self.substring(idx + from.len(), self.len()) + }, + None => *self, + } + } + + /// Repeat string n times + pub fn repeat(self: ref Self, n: Int) -> String { + let mut result = ""; + for _ in 0..n { + result = result + self; + } + result + } + + /// Pad start to reach length + pub fn pad_start(self: ref Self, length: Int, pad: Char) -> String { + 
if self.len() >= length { + *self + } else { + pad.to_string().repeat(length - self.len()) + self + } + } + + /// Pad end to reach length + pub fn pad_end(self: ref Self, length: Int, pad: Char) -> String { + if self.len() >= length { + *self + } else { + self + pad.to_string().repeat(length - self.len()) + } + } + + /// Reverse the string + pub fn reverse(self: ref Self) -> String { + let chars = self.chars().collect(); + let mut result = ""; + for i in (0..len(chars)).rev() { + result = result + chars[i].to_string(); + } + result + } + + /// Get iterator over characters + pub fn chars(self: ref Self) -> Chars { + Chars { string: self, index: 0 } + } + + /// Get iterator over bytes + pub fn bytes(self: ref Self) -> Bytes { + Bytes { string: self, index: 0 } + } + + /// Check if string matches pattern (simplified glob) + pub fn matches(self: ref Self, pattern: String) -> Bool { + // Simplified pattern matching + if pattern == "*" { + true + } else if pattern.starts_with("*") && pattern.ends_with("*") { + let middle = pattern.substring(1, pattern.len() - 1); + self.contains(middle) + } else if pattern.starts_with("*") { + self.ends_with(pattern.substring(1, pattern.len())) + } else if pattern.ends_with("*") { + self.starts_with(pattern.substring(0, pattern.len() - 1)) + } else { + *self == pattern + } + } + + /// Parse string to integer + pub fn parse_int(self: ref Self) -> Option[Int] { + // Native implementation + None + } + + /// Parse string to float + pub fn parse_float(self: ref Self) -> Option[Float] { + // Native implementation + None + } + + /// Parse string to bool + pub fn parse_bool(self: ref Self) -> Option[Bool] { + match self.to_lowercase() { + "true" | "yes" | "1" => Some(true), + "false" | "no" | "0" => Some(false), + _ => None, + } + } + + /// Check if all characters satisfy predicate + pub fn all(self: ref Self, pred: fn(Char) -> Bool) -> Bool { + for c in self.chars() { + if !pred(c) { + return false; + } + } + true + } + + /// Check if any 
character satisfies predicate + pub fn any(self: ref Self, pred: fn(Char) -> Bool) -> Bool { + for c in self.chars() { + if pred(c) { + return true; + } + } + false + } + + /// Check if string is alphanumeric + pub fn is_alphanumeric(self: ref Self) -> Bool { + !self.is_empty() && self.all(|c| c.is_alphanumeric()) + } + + /// Check if string is alphabetic + pub fn is_alphabetic(self: ref Self) -> Bool { + !self.is_empty() && self.all(|c| c.is_alphabetic()) + } + + /// Check if string is numeric + pub fn is_numeric(self: ref Self) -> Bool { + !self.is_empty() && self.all(|c| c.is_digit()) + } + + /// Check if string is whitespace + pub fn is_whitespace(self: ref Self) -> Bool { + !self.is_empty() && self.all(|c| c.is_whitespace()) + } +} + +// ============================================================================ +// Character Iterator +// ============================================================================ + +pub struct Chars { + string: ref String, + index: Int, +} + +impl Iterator for Chars { + type Item = Char; + + fn next(self: mut Self) -> Option[Char] { + if self.index < self.string.len() { + let c = self.string.char_at(self.index); + self.index += 1; + c + } else { + None + } + } +} + +// ============================================================================ +// Byte Iterator +// ============================================================================ + +pub struct Bytes { + string: ref String, + index: Int, +} + +impl Iterator for Bytes { + type Item = u8; + + fn next(self: mut Self) -> Option[u8] { + if self.index < self.string.len() { + // Would get actual byte + self.index += 1; + Some(0u8) + } else { + None + } + } +} + +// ============================================================================ +// Character Extensions +// ============================================================================ + +impl Char { + /// Check if character is alphabetic + pub fn is_alphabetic(self: Self) -> Bool { + (self >= 'a' && self <= 
'z') || (self >= 'A' && self <= 'Z') + } + + /// Check if character is digit + pub fn is_digit(self: Self) -> Bool { + self >= '0' && self <= '9' + } + + /// Check if character is alphanumeric + pub fn is_alphanumeric(self: Self) -> Bool { + self.is_alphabetic() || self.is_digit() + } + + /// Check if character is whitespace + pub fn is_whitespace(self: Self) -> Bool { + self == ' ' || self == '\t' || self == '\n' || self == '\r' + } + + /// Check if character is uppercase + pub fn is_uppercase(self: Self) -> Bool { + self >= 'A' && self <= 'Z' + } + + /// Check if character is lowercase + pub fn is_lowercase(self: Self) -> Bool { + self >= 'a' && self <= 'z' + } + + /// Convert to uppercase + pub fn to_uppercase(self: Self) -> Char { + if self.is_lowercase() { + // ASCII conversion + ((self as u8) - 32) as Char + } else { + self + } + } + + /// Convert to lowercase + pub fn to_lowercase(self: Self) -> Char { + if self.is_uppercase() { + // ASCII conversion + ((self as u8) + 32) as Char + } else { + self + } + } + + /// Convert to string + pub fn to_string(self: Self) -> String { + // Native implementation + "" + } + + /// Get digit value (0-9) + pub fn digit_value(self: Self) -> Option[Int] { + if self.is_digit() { + Some((self as u8 - '0' as u8) as Int) + } else { + None + } + } + + /// Get hex digit value (0-15) + pub fn hex_digit_value(self: Self) -> Option[Int] { + if self.is_digit() { + Some((self as u8 - '0' as u8) as Int) + } else if self >= 'a' && self <= 'f' { + Some((self as u8 - 'a' as u8 + 10) as Int) + } else if self >= 'A' && self <= 'F' { + Some((self as u8 - 'A' as u8 + 10) as Int) + } else { + None + } + } +} + +// ============================================================================ +// String Builder +// ============================================================================ + +/// Efficient string building +pub struct StringBuilder { + parts: Vec[String], + len: Int, +} + +impl StringBuilder { + /// Create new empty builder + pub fn 
new() -> StringBuilder { + StringBuilder { parts: Vec::new(), len: 0 } + } + + /// Create with initial capacity + pub fn with_capacity(cap: Int) -> StringBuilder { + StringBuilder { parts: Vec::with_capacity(cap), len: 0 } + } + + /// Append string + pub fn push(self: mut Self, s: String) -> mut Self { + self.len += s.len(); + self.parts.push(s); + self + } + + /// Append character + pub fn push_char(self: mut Self, c: Char) -> mut Self { + self.push(c.to_string()) + } + + /// Append line with newline + pub fn push_line(self: mut Self, s: String) -> mut Self { + self.push(s).push("\n") + } + + /// Get current length + pub fn len(self: ref Self) -> Int { + self.len + } + + /// Check if empty + pub fn is_empty(self: ref Self) -> Bool { + self.len == 0 + } + + /// Clear contents + pub fn clear(self: mut Self) { + self.parts.clear(); + self.len = 0; + } + + /// Build final string + pub fn build(self: Self) -> String { + String::join(self.parts.to_array(), "") + } + + /// Build and clear + pub fn take(self: mut Self) -> String { + let result = self.build(); + self.clear(); + result + } +} + +impl Display for StringBuilder { + fn fmt(self: ref Self) -> String { + self.build() + } +} + +// ============================================================================ +// Formatting Utilities +// ============================================================================ + +/// Format integer with thousands separator +pub fn format_number(n: Int, separator: String) -> String { + let s = n.to_string(); + let negative = n < 0; + let digits = if negative { s.substring(1, s.len()) } else { s }; + + let mut result = ""; + let len = digits.len(); + for i in 0..len { + if i > 0 && (len - i) % 3 == 0 { + result = result + separator; + } + result = result + digits.char_at(i).unwrap().to_string(); + } + + if negative { "-" + result } else { result } +} + +/// Format with placeholders: format("Hello, {}!", ["world"]) +pub fn format(template: String, args: Array[String]) -> String { + 
let mut result = template; + for arg in args { + result = result.replace_first("{}", arg); + } + result +} + +/// Left-justify string in field +pub fn ljust(s: String, width: Int) -> String { + s.pad_end(width, ' ') +} + +/// Right-justify string in field +pub fn rjust(s: String, width: Int) -> String { + s.pad_start(width, ' ') +} + +/// Center string in field +pub fn center(s: String, width: Int) -> String { + if s.len() >= width { + s + } else { + let padding = width - s.len(); + let left = padding / 2; + let right = padding - left; + " ".repeat(left) + s + " ".repeat(right) + } +} + diff --git a/library/common/sync.afs b/library/common/sync.afs new file mode 100644 index 0000000..5800d4b --- /dev/null +++ b/library/common/sync.afs @@ -0,0 +1,524 @@ +// AffineScript Common Library - Sync +// Synchronization primitives (non-async) + +module Common.Sync; + +use Common.Prelude::*; + +// ============================================================================ +// Atomic Types +// ============================================================================ + +/// Atomic boolean +pub struct AtomicBool { + // Implementation hidden - native atomic + value: Bool, +} + +impl AtomicBool { + pub fn new(value: Bool) -> AtomicBool { + AtomicBool { value } + } + + pub fn load(self: ref Self) -> Bool { + // Atomic load + self.value + } + + pub fn store(self: ref Self, value: Bool) { + // Atomic store + // Native implementation + } + + pub fn swap(self: ref Self, value: Bool) -> Bool { + let old = self.value; + // Atomic swap + old + } + + pub fn compare_exchange(self: ref Self, current: Bool, new: Bool) -> Result[Bool, Bool] { + if self.value == current { + // Atomic CAS + Ok(current) + } else { + Err(self.value) + } + } + + pub fn fetch_and(self: ref Self, value: Bool) -> Bool { + let old = self.value; + // Atomic AND + old + } + + pub fn fetch_or(self: ref Self, value: Bool) -> Bool { + let old = self.value; + // Atomic OR + old + } + + pub fn fetch_xor(self: ref Self, 
value: Bool) -> Bool { + let old = self.value; + // Atomic XOR + old + } +} + +/// Atomic integer +pub struct AtomicInt { + value: Int, +} + +impl AtomicInt { + pub fn new(value: Int) -> AtomicInt { + AtomicInt { value } + } + + pub fn load(self: ref Self) -> Int { + self.value + } + + pub fn store(self: ref Self, value: Int) { + // Atomic store + } + + pub fn swap(self: ref Self, value: Int) -> Int { + let old = self.value; + old + } + + pub fn compare_exchange(self: ref Self, current: Int, new: Int) -> Result[Int, Int] { + if self.value == current { + Ok(current) + } else { + Err(self.value) + } + } + + pub fn fetch_add(self: ref Self, value: Int) -> Int { + let old = self.value; + // Atomic add + old + } + + pub fn fetch_sub(self: ref Self, value: Int) -> Int { + let old = self.value; + // Atomic sub + old + } + + pub fn fetch_max(self: ref Self, value: Int) -> Int { + let old = self.value; + // Atomic max + old + } + + pub fn fetch_min(self: ref Self, value: Int) -> Int { + let old = self.value; + // Atomic min + old + } +} + +// ============================================================================ +// Once (One-time initialization) +// ============================================================================ + +/// Ensures a function runs exactly once +pub struct Once { + done: AtomicBool, +} + +impl Once { + pub fn new() -> Once { + Once { done: AtomicBool::new(false) } + } + + /// Run function if not already run + pub fn call_once(self: ref Self, f: fn() -> ()) { + if !self.done.swap(true) { + f(); + } + } + + /// Check if already initialized + pub fn is_completed(self: ref Self) -> Bool { + self.done.load() + } +} + +/// Lazy initialization +pub struct Lazy[T] { + once: Once, + value: Option[T], + init: fn() -> T, +} + +impl[T] Lazy[T] { + pub fn new(init: fn() -> T) -> Lazy[T] { + Lazy { + once: Once::new(), + value: None, + init, + } + } + + pub fn get(self: mut Self) -> ref T { + self.once.call_once(|| { + self.value = Some((self.init)()); + 
}); + self.value.as_ref().unwrap() + } + + pub fn is_initialized(self: ref Self) -> Bool { + self.once.is_completed() + } +} + +// ============================================================================ +// Cell (Interior Mutability) +// ============================================================================ + +/// Single-threaded interior mutability +pub struct Cell[T: Copy] { + value: T, +} + +impl[T: Copy] Cell[T] { + pub fn new(value: T) -> Cell[T] { + Cell { value } + } + + pub fn get(self: ref Self) -> T { + self.value + } + + pub fn set(self: ref Self, value: T) { + // Interior mutation + } + + pub fn replace(self: ref Self, value: T) -> T { + let old = self.value; + self.set(value); + old + } + + pub fn swap(self: ref Self, other: ref Cell[T]) { + let tmp = self.get(); + self.set(other.get()); + other.set(tmp); + } +} + +/// Reference-counted interior mutability +pub struct RefCell[T] { + value: T, + borrow_state: Int, // 0 = not borrowed, >0 = shared borrows, -1 = mut borrowed +} + +impl[T] RefCell[T] { + pub fn new(value: T) -> RefCell[T] { + RefCell { value, borrow_state: 0 } + } + + /// Try to borrow immutably + pub fn try_borrow(self: ref Self) -> Option[Ref[T]] { + if self.borrow_state >= 0 { + Some(Ref { cell: self }) + } else { + None + } + } + + /// Borrow immutably (panics if already mutably borrowed) + pub fn borrow(self: ref Self) -> Ref[T] { + self.try_borrow().expect("already mutably borrowed") + } + + /// Try to borrow mutably + pub fn try_borrow_mut(self: ref Self) -> Option[RefMut[T]] { + if self.borrow_state == 0 { + Some(RefMut { cell: self }) + } else { + None + } + } + + /// Borrow mutably (panics if already borrowed) + pub fn borrow_mut(self: ref Self) -> RefMut[T] { + self.try_borrow_mut().expect("already borrowed") + } +} + +/// Immutable borrow guard +pub struct Ref[T] { + cell: ref RefCell[T], +} + +impl[T] Ref[T] { + pub fn get(self: ref Self) -> ref T { + &self.cell.value + } +} + +/// Mutable borrow guard +pub struct 
RefMut[T] { + cell: ref RefCell[T], +} + +impl[T] RefMut[T] { + pub fn get(self: ref Self) -> ref T { + &self.cell.value + } + + pub fn get_mut(self: mut Self) -> mut T { + &mut self.cell.value + } +} + +// ============================================================================ +// Spinlock (Simple busy-wait lock) +// ============================================================================ + +/// Simple spinlock +pub struct SpinLock[T] { + locked: AtomicBool, + value: T, +} + +impl[T] SpinLock[T] { + pub fn new(value: T) -> SpinLock[T] { + SpinLock { + locked: AtomicBool::new(false), + value, + } + } + + /// Lock and get guard + pub fn lock(self: ref Self) -> SpinLockGuard[T] { + while self.locked.swap(true) { + // Spin + } + SpinLockGuard { lock: self } + } + + /// Try to lock without spinning + pub fn try_lock(self: ref Self) -> Option[SpinLockGuard[T]] { + if !self.locked.swap(true) { + Some(SpinLockGuard { lock: self }) + } else { + None + } + } + + /// Check if locked + pub fn is_locked(self: ref Self) -> Bool { + self.locked.load() + } +} + +pub struct SpinLockGuard[T] { + lock: ref SpinLock[T], +} + +impl[T] SpinLockGuard[T] { + pub fn get(self: ref Self) -> ref T { + &self.lock.value + } + + pub fn get_mut(self: mut Self) -> mut T { + &mut self.lock.value + } +} + +// Guard releases lock when dropped +impl[T] Drop for SpinLockGuard[T] { + fn drop(self: mut Self) { + self.lock.locked.store(false); + } +} + +// ============================================================================ +// RwLock (Reader-Writer Lock) +// ============================================================================ + +/// Reader-writer lock (multiple readers OR one writer) +pub struct RwLock[T] { + /// Number of readers, or -1 if write-locked + state: AtomicInt, + value: T, +} + +impl[T] RwLock[T] { + pub fn new(value: T) -> RwLock[T] { + RwLock { + state: AtomicInt::new(0), + value, + } + } + + /// Acquire read lock + pub fn read(self: ref Self) -> RwLockReadGuard[T] 
{ + loop { + let state = self.state.load(); + if state >= 0 { + if self.state.compare_exchange(state, state + 1).is_ok() { + break; + } + } + } + RwLockReadGuard { lock: self } + } + + /// Try to acquire read lock + pub fn try_read(self: ref Self) -> Option[RwLockReadGuard[T]] { + let state = self.state.load(); + if state >= 0 { + if self.state.compare_exchange(state, state + 1).is_ok() { + Some(RwLockReadGuard { lock: self }) + } else { + None + } + } else { + None + } + } + + /// Acquire write lock + pub fn write(self: ref Self) -> RwLockWriteGuard[T] { + loop { + if self.state.compare_exchange(0, -1).is_ok() { + break; + } + } + RwLockWriteGuard { lock: self } + } + + /// Try to acquire write lock + pub fn try_write(self: ref Self) -> Option[RwLockWriteGuard[T]] { + if self.state.compare_exchange(0, -1).is_ok() { + Some(RwLockWriteGuard { lock: self }) + } else { + None + } + } +} + +pub struct RwLockReadGuard[T] { + lock: ref RwLock[T], +} + +impl[T] RwLockReadGuard[T] { + pub fn get(self: ref Self) -> ref T { + &self.lock.value + } +} + +impl[T] Drop for RwLockReadGuard[T] { + fn drop(self: mut Self) { + self.lock.state.fetch_sub(1); + } +} + +pub struct RwLockWriteGuard[T] { + lock: ref RwLock[T], +} + +impl[T] RwLockWriteGuard[T] { + pub fn get(self: ref Self) -> ref T { + &self.lock.value + } + + pub fn get_mut(self: mut Self) -> mut T { + &mut self.lock.value + } +} + +impl[T] Drop for RwLockWriteGuard[T] { + fn drop(self: mut Self) { + self.lock.state.store(0); + } +} + +// ============================================================================ +// Barrier +// ============================================================================ + +/// Thread barrier for synchronization +pub struct Barrier { + count: Int, + waiting: AtomicInt, + generation: AtomicInt, +} + +impl Barrier { + pub fn new(count: Int) -> Barrier { + Barrier { + count, + waiting: AtomicInt::new(0), + generation: AtomicInt::new(0), + } + } + + /// Wait at barrier until all threads 
arrive + pub fn wait(self: ref Self) -> BarrierWaitResult { + let gen = self.generation.load(); + let arrived = self.waiting.fetch_add(1) + 1; + + if arrived == self.count { + // Last to arrive - release everyone + self.waiting.store(0); + self.generation.fetch_add(1); + BarrierWaitResult { is_leader: true } + } else { + // Wait for generation to change + while self.generation.load() == gen { + // Spin + } + BarrierWaitResult { is_leader: false } + } + } +} + +pub struct BarrierWaitResult { + is_leader: Bool, +} + +impl BarrierWaitResult { + pub fn is_leader(self: ref Self) -> Bool { + self.is_leader + } +} + +// ============================================================================ +// CountDownLatch +// ============================================================================ + +/// Latch that counts down to zero +pub struct CountDownLatch { + count: AtomicInt, +} + +impl CountDownLatch { + pub fn new(count: Int) -> CountDownLatch { + CountDownLatch { count: AtomicInt::new(count) } + } + + /// Decrement count + pub fn count_down(self: ref Self) { + self.count.fetch_sub(1); + } + + /// Wait until count reaches zero + pub fn wait(self: ref Self) { + while self.count.load() > 0 { + // Spin + } + } + + /// Get current count + pub fn count(self: ref Self) -> Int { + self.count.load() + } +} + diff --git a/library/common/time.afs b/library/common/time.afs new file mode 100644 index 0000000..46ce210 --- /dev/null +++ b/library/common/time.afs @@ -0,0 +1,571 @@ +// AffineScript Common Library - Time +// Time and duration handling + +module Common.Time; + +use Common.Prelude::*; + +// ============================================================================ +// Duration +// ============================================================================ + +/// A duration of time (in nanoseconds internally) +pub struct Duration { + nanos: i64, +} + +impl Duration { + /// Zero duration + pub const ZERO: Duration = Duration { nanos: 0 }; + + /// Maximum duration + 
pub const MAX: Duration = Duration { nanos: 9223372036854775807 }; + + /// Create duration from nanoseconds + pub fn from_nanos(nanos: i64) -> Duration { + Duration { nanos } + } + + /// Create duration from microseconds + pub fn from_micros(micros: i64) -> Duration { + Duration { nanos: micros * 1000 } + } + + /// Create duration from milliseconds + pub fn from_millis(millis: i64) -> Duration { + Duration { nanos: millis * 1_000_000 } + } + + /// Create duration from seconds + pub fn from_secs(secs: i64) -> Duration { + Duration { nanos: secs * 1_000_000_000 } + } + + /// Create duration from minutes + pub fn from_mins(mins: i64) -> Duration { + Duration::from_secs(mins * 60) + } + + /// Create duration from hours + pub fn from_hours(hours: i64) -> Duration { + Duration::from_mins(hours * 60) + } + + /// Create duration from days + pub fn from_days(days: i64) -> Duration { + Duration::from_hours(days * 24) + } + + /// Create duration from seconds with fractional part + pub fn from_secs_f64(secs: Float) -> Duration { + Duration { nanos: (secs * 1_000_000_000.0) as i64 } + } + + /// Get total nanoseconds + pub fn as_nanos(self: ref Self) -> i64 { + self.nanos + } + + /// Get total microseconds + pub fn as_micros(self: ref Self) -> i64 { + self.nanos / 1000 + } + + /// Get total milliseconds + pub fn as_millis(self: ref Self) -> i64 { + self.nanos / 1_000_000 + } + + /// Get total seconds + pub fn as_secs(self: ref Self) -> i64 { + self.nanos / 1_000_000_000 + } + + /// Get as floating-point seconds + pub fn as_secs_f64(self: ref Self) -> Float { + (self.nanos as Float) / 1_000_000_000.0 + } + + /// Get subsecond nanoseconds + pub fn subsec_nanos(self: ref Self) -> i32 { + (self.nanos % 1_000_000_000) as i32 + } + + /// Get subsecond microseconds + pub fn subsec_micros(self: ref Self) -> i32 { + (self.nanos % 1_000_000_000 / 1000) as i32 + } + + /// Get subsecond milliseconds + pub fn subsec_millis(self: ref Self) -> i32 { + (self.nanos % 1_000_000_000 / 1_000_000) 
as i32 + } + + /// Check if zero + pub fn is_zero(self: ref Self) -> Bool { + self.nanos == 0 + } + + /// Checked addition + pub fn checked_add(self: ref Self, other: Duration) -> Option[Duration] { + // Would check for overflow + Some(Duration { nanos: self.nanos + other.nanos }) + } + + /// Checked subtraction + pub fn checked_sub(self: ref Self, other: Duration) -> Option[Duration] { + if other.nanos > self.nanos { + None + } else { + Some(Duration { nanos: self.nanos - other.nanos }) + } + } + + /// Checked multiplication + pub fn checked_mul(self: ref Self, rhs: i32) -> Option[Duration] { + Some(Duration { nanos: self.nanos * (rhs as i64) }) + } + + /// Saturating addition + pub fn saturating_add(self: ref Self, other: Duration) -> Duration { + self.checked_add(other).unwrap_or(Duration::MAX) + } + + /// Saturating subtraction + pub fn saturating_sub(self: ref Self, other: Duration) -> Duration { + self.checked_sub(other).unwrap_or(Duration::ZERO) + } + + /// Multiply by scalar + pub fn mul(self: ref Self, rhs: i32) -> Duration { + Duration { nanos: self.nanos * (rhs as i64) } + } + + /// Divide by scalar + pub fn div(self: ref Self, rhs: i32) -> Duration { + Duration { nanos: self.nanos / (rhs as i64) } + } +} + +impl Eq for Duration { + fn eq(self: ref Self, other: ref Self) -> Bool { + self.nanos == other.nanos + } +} + +impl Ord for Duration { + fn cmp(self: ref Self, other: ref Self) -> Ordering { + if self.nanos < other.nanos { Less } + else if self.nanos > other.nanos { Greater } + else { Equal } + } +} + +impl Display for Duration { + fn fmt(self: ref Self) -> String { + let secs = self.as_secs(); + let millis = self.subsec_millis(); + if secs > 0 { + secs.to_string() + "." 
+ millis.to_string().pad_start(3, '0') + "s" + } else { + millis.to_string() + "ms" + } + } +} + +// ============================================================================ +// Instant +// ============================================================================ + +/// A point in time (monotonic clock) +pub struct Instant { + nanos: i64, +} + +impl Instant { + /// Get current instant + pub fn now() -> Instant / Time { + Time::now() + } + + /// Duration since this instant + pub fn elapsed(self: ref Self) -> Duration / Time { + Instant::now().duration_since(*self) + } + + /// Duration between two instants + pub fn duration_since(self: ref Self, earlier: Instant) -> Duration { + Duration { nanos: self.nanos - earlier.nanos } + } + + /// Checked duration since + pub fn checked_duration_since(self: ref Self, earlier: Instant) -> Option[Duration] { + if self.nanos >= earlier.nanos { + Some(Duration { nanos: self.nanos - earlier.nanos }) + } else { + None + } + } + + /// Saturating duration since + pub fn saturating_duration_since(self: ref Self, earlier: Instant) -> Duration { + self.checked_duration_since(earlier).unwrap_or(Duration::ZERO) + } + + /// Add duration + pub fn add(self: ref Self, duration: Duration) -> Instant { + Instant { nanos: self.nanos + duration.nanos } + } + + /// Subtract duration + pub fn sub(self: ref Self, duration: Duration) -> Instant { + Instant { nanos: self.nanos - duration.nanos } + } +} + +impl Eq for Instant { + fn eq(self: ref Self, other: ref Self) -> Bool { + self.nanos == other.nanos + } +} + +impl Ord for Instant { + fn cmp(self: ref Self, other: ref Self) -> Ordering { + if self.nanos < other.nanos { Less } + else if self.nanos > other.nanos { Greater } + else { Equal } + } +} + +// ============================================================================ +// Time Effect +// ============================================================================ + +/// Time effect for time-related operations +pub effect Time { + /// 
Get current monotonic instant + fn now() -> Instant; + + /// Get current system time + fn system_time() -> SystemTime; + + /// Sleep for duration + fn sleep(duration: Duration) -> (); +} + +// ============================================================================ +// System Time (Wall Clock) +// ============================================================================ + +/// System time (wall clock, can go backwards) +pub struct SystemTime { + /// Seconds since Unix epoch + secs: i64, + /// Nanoseconds within second + nanos: i32, +} + +impl SystemTime { + /// Unix epoch (1970-01-01 00:00:00 UTC) + pub const UNIX_EPOCH: SystemTime = SystemTime { secs: 0, nanos: 0 }; + + /// Get current system time + pub fn now() -> SystemTime / Time { + Time::system_time() + } + + /// Duration since earlier time + pub fn duration_since(self: ref Self, earlier: SystemTime) -> Result[Duration, TimeError] { + let secs_diff = self.secs - earlier.secs; + let nanos_diff = (self.nanos - earlier.nanos) as i64; + let total_nanos = secs_diff * 1_000_000_000 + nanos_diff; + if total_nanos >= 0 { + Ok(Duration { nanos: total_nanos }) + } else { + Err(TimeError::SystemTimeError) + } + } + + /// Duration since Unix epoch + pub fn duration_since_epoch(self: ref Self) -> Duration { + self.duration_since(SystemTime::UNIX_EPOCH).unwrap() + } + + /// Elapsed since this time + pub fn elapsed(self: ref Self) -> Result[Duration, TimeError] / Time { + SystemTime::now().duration_since(*self) + } + + /// Add duration + pub fn add(self: ref Self, duration: Duration) -> SystemTime { + let total_nanos = self.nanos as i64 + duration.nanos; + let extra_secs = total_nanos / 1_000_000_000; + let remaining_nanos = total_nanos % 1_000_000_000; + SystemTime { + secs: self.secs + extra_secs, + nanos: remaining_nanos as i32, + } + } + + /// Subtract duration + pub fn sub(self: ref Self, duration: Duration) -> SystemTime { + let total_nanos = self.nanos as i64 - duration.nanos; + if total_nanos >= 0 { + 
SystemTime { + secs: self.secs, + nanos: total_nanos as i32, + } + } else { + let borrow_secs = (-total_nanos - 1) / 1_000_000_000 + 1; + SystemTime { + secs: self.secs - borrow_secs, + nanos: (total_nanos + borrow_secs * 1_000_000_000) as i32, + } + } + } +} + +/// Time-related errors +pub enum TimeError { + /// System time went backwards + SystemTimeError, +} + +// ============================================================================ +// DateTime (Simplified) +// ============================================================================ + +/// A date and time (simplified, no timezone) +pub struct DateTime { + year: i32, + month: u8, // 1-12 + day: u8, // 1-31 + hour: u8, // 0-23 + minute: u8, // 0-59 + second: u8, // 0-59 + nanosecond: i32, +} + +impl DateTime { + /// Create from components + pub fn new(year: i32, month: u8, day: u8, hour: u8, minute: u8, second: u8) -> DateTime { + DateTime { year, month, day, hour, minute, second, nanosecond: 0 } + } + + /// Create from Unix timestamp + pub fn from_timestamp(secs: i64) -> DateTime { + // Simplified - would do proper calendar calculation + let days = secs / 86400; + let time_secs = secs % 86400; + DateTime { + year: 1970 + (days / 365) as i32, + month: 1, + day: 1, + hour: (time_secs / 3600) as u8, + minute: ((time_secs % 3600) / 60) as u8, + second: (time_secs % 60) as u8, + nanosecond: 0, + } + } + + /// Get year + pub fn year(self: ref Self) -> i32 { self.year } + + /// Get month (1-12) + pub fn month(self: ref Self) -> u8 { self.month } + + /// Get day of month (1-31) + pub fn day(self: ref Self) -> u8 { self.day } + + /// Get hour (0-23) + pub fn hour(self: ref Self) -> u8 { self.hour } + + /// Get minute (0-59) + pub fn minute(self: ref Self) -> u8 { self.minute } + + /// Get second (0-59) + pub fn second(self: ref Self) -> u8 { self.second } + + /// Format as ISO 8601 string + pub fn to_iso_string(self: ref Self) -> String { + format( + "{}-{}-{}T{}:{}:{}", + [ + 
self.year.to_string().pad_start(4, '0'), + self.month.to_string().pad_start(2, '0'), + self.day.to_string().pad_start(2, '0'), + self.hour.to_string().pad_start(2, '0'), + self.minute.to_string().pad_start(2, '0'), + self.second.to_string().pad_start(2, '0'), + ] + ) + } + + /// Format as date only + pub fn to_date_string(self: ref Self) -> String { + format( + "{}-{}-{}", + [ + self.year.to_string().pad_start(4, '0'), + self.month.to_string().pad_start(2, '0'), + self.day.to_string().pad_start(2, '0'), + ] + ) + } + + /// Format as time only + pub fn to_time_string(self: ref Self) -> String { + format( + "{}:{}:{}", + [ + self.hour.to_string().pad_start(2, '0'), + self.minute.to_string().pad_start(2, '0'), + self.second.to_string().pad_start(2, '0'), + ] + ) + } +} + +impl Display for DateTime { + fn fmt(self: ref Self) -> String { + self.to_iso_string() + } +} + +// ============================================================================ +// Timing Utilities +// ============================================================================ + +/// Measure execution time of a function +pub fn measure[T](f: fn() -> T) -> (T, Duration) / Time { + let start = Instant::now(); + let result = f(); + let elapsed = start.elapsed(); + (result, elapsed) +} + +/// Time a function and print result +pub fn time_it[T](name: String, f: fn() -> T) -> T / Time, Console { + let (result, elapsed) = measure(f); + Console::println(format("{}: {}", [name, elapsed.to_string()])); + result +} + +/// Simple stopwatch +pub struct Stopwatch { + start: Option[Instant], + accumulated: Duration, + running: Bool, +} + +impl Stopwatch { + /// Create new stopped stopwatch + pub fn new() -> Stopwatch { + Stopwatch { + start: None, + accumulated: Duration::ZERO, + running: false, + } + } + + /// Create and start stopwatch + pub fn start_new() -> Stopwatch / Time { + let mut sw = Stopwatch::new(); + sw.start(); + sw + } + + /// Start or resume + pub fn start(self: mut Self) / Time { + if 
!self.running { + self.start = Some(Instant::now()); + self.running = true; + } + } + + /// Stop + pub fn stop(self: mut Self) / Time { + if self.running { + if let Some(start) = self.start { + self.accumulated = self.accumulated.saturating_add(start.elapsed()); + } + self.start = None; + self.running = false; + } + } + + /// Reset to zero + pub fn reset(self: mut Self) { + self.start = None; + self.accumulated = Duration::ZERO; + self.running = false; + } + + /// Restart from zero + pub fn restart(self: mut Self) / Time { + self.reset(); + self.start(); + } + + /// Get elapsed time + pub fn elapsed(self: ref Self) -> Duration / Time { + if self.running { + if let Some(start) = self.start { + self.accumulated.saturating_add(start.elapsed()) + } else { + self.accumulated + } + } else { + self.accumulated + } + } + + /// Check if running + pub fn is_running(self: ref Self) -> Bool { + self.running + } +} + +/// Lap timer for tracking splits +pub struct LapTimer { + stopwatch: Stopwatch, + laps: Vec[Duration], +} + +impl LapTimer { + pub fn new() -> LapTimer { + LapTimer { + stopwatch: Stopwatch::new(), + laps: Vec::new(), + } + } + + pub fn start(self: mut Self) / Time { + self.stopwatch.start(); + } + + pub fn lap(self: mut Self) -> Duration / Time { + let elapsed = self.stopwatch.elapsed(); + let last_total: Duration = if self.laps.is_empty() { + Duration::ZERO + } else { + let mut sum = Duration::ZERO; + for d in &self.laps { + sum = sum.saturating_add(*d); + } + sum + }; + let lap_time = elapsed.saturating_sub(last_total); + self.laps.push(lap_time); + lap_time + } + + pub fn total(self: ref Self) -> Duration / Time { + self.stopwatch.elapsed() + } + + pub fn laps(self: ref Self) -> ref Vec[Duration] { + &self.laps + } +} + diff --git a/test/test_eval.ml b/test/test_eval.ml new file mode 100644 index 0000000..84daedc --- /dev/null +++ b/test/test_eval.ml @@ -0,0 +1,285 @@ +(** Tests for the AffineScript interpreter *) + +open Affinescript + +let eval s = + let 
env = Value.empty_env () in + Stdlib.load_prelude env; + match Repl.eval_string ~env s with + | Ok v -> v + | Error msg -> failwith msg + +let eval_to_int s = + match eval s with + | Value.VInt i -> i + | v -> failwith (Printf.sprintf "Expected int, got %s" (Value.show v)) + +let eval_to_bool s = + match eval s with + | Value.VBool b -> b + | v -> failwith (Printf.sprintf "Expected bool, got %s" (Value.show v)) + +let eval_to_string s = + match eval s with + | Value.VString s -> s + | v -> failwith (Printf.sprintf "Expected string, got %s" (Value.show v)) + +(* ========== Literal Tests ========== *) + +let test_int_literal () = + Alcotest.(check int) "int literal" 42 (eval_to_int "42") + +let test_negative_int () = + Alcotest.(check int) "negative int" (-5) (eval_to_int "-5") + +let test_bool_literal () = + Alcotest.(check bool) "true" true (eval_to_bool "true"); + Alcotest.(check bool) "false" false (eval_to_bool "false") + +let test_string_literal () = + Alcotest.(check string) "string" "hello" (eval_to_string "\"hello\"") + +(* ========== Arithmetic Tests ========== *) + +let test_addition () = + Alcotest.(check int) "1 + 2" 3 (eval_to_int "1 + 2") + +let test_subtraction () = + Alcotest.(check int) "5 - 3" 2 (eval_to_int "5 - 3") + +let test_multiplication () = + Alcotest.(check int) "4 * 5" 20 (eval_to_int "4 * 5") + +let test_division () = + Alcotest.(check int) "10 / 3" 3 (eval_to_int "10 / 3") + +let test_modulo () = + Alcotest.(check int) "10 % 3" 1 (eval_to_int "10 % 3") + +let test_precedence () = + Alcotest.(check int) "1 + 2 * 3" 7 (eval_to_int "1 + 2 * 3"); + Alcotest.(check int) "(1 + 2) * 3" 9 (eval_to_int "(1 + 2) * 3") + +let test_complex_expr () = + Alcotest.(check int) "complex" 14 (eval_to_int "2 * 3 + 4 * 2") + +(* ========== Comparison Tests ========== *) + +let test_equality () = + Alcotest.(check bool) "1 == 1" true (eval_to_bool "1 == 1"); + Alcotest.(check bool) "1 == 2" false (eval_to_bool "1 == 2") + +let test_inequality () = + 
Alcotest.(check bool) "1 != 2" true (eval_to_bool "1 != 2"); + Alcotest.(check bool) "1 != 1" false (eval_to_bool "1 != 1") + +let test_less_than () = + Alcotest.(check bool) "1 < 2" true (eval_to_bool "1 < 2"); + Alcotest.(check bool) "2 < 1" false (eval_to_bool "2 < 1") + +let test_greater_than () = + Alcotest.(check bool) "2 > 1" true (eval_to_bool "2 > 1"); + Alcotest.(check bool) "1 > 2" false (eval_to_bool "1 > 2") + +(* ========== Logical Tests ========== *) + +let test_and () = + Alcotest.(check bool) "true && true" true (eval_to_bool "true && true"); + Alcotest.(check bool) "true && false" false (eval_to_bool "true && false") + +let test_or () = + Alcotest.(check bool) "false || true" true (eval_to_bool "false || true"); + Alcotest.(check bool) "false || false" false (eval_to_bool "false || false") + +let test_not () = + Alcotest.(check bool) "!true" false (eval_to_bool "!true"); + Alcotest.(check bool) "!false" true (eval_to_bool "!false") + +(* ========== Let Binding Tests ========== *) + +let test_let_binding () = + Alcotest.(check int) "let x = 5; x" 5 (eval_to_int "{ let x = 5; x }") + +let test_let_shadowing () = + Alcotest.(check int) "shadowing" 10 + (eval_to_int "{ let x = 5; let x = 10; x }") + +let test_let_in_expr () = + Alcotest.(check int) "let in expr" 15 + (eval_to_int "{ let x = 5; let y = 10; x + y }") + +(* ========== If Expression Tests ========== *) + +let test_if_true () = + Alcotest.(check int) "if true" 1 (eval_to_int "if true { 1 } else { 2 }") + +let test_if_false () = + Alcotest.(check int) "if false" 2 (eval_to_int "if false { 1 } else { 2 }") + +let test_nested_if () = + Alcotest.(check int) "nested if" 3 + (eval_to_int "if false { 1 } else if false { 2 } else { 3 }") + +(* ========== Function Tests ========== *) + +let test_function_def () = + Alcotest.(check int) "fn call" 5 + (eval_to_int "{ fn add(a: Int, b: Int) -> Int { a + b } add(2, 3) }") + +let test_recursive_fn () = + Alcotest.(check int) "recursive" 120 + 
(eval_to_int {|{ + fn fact(n: Int) -> Int { + if n <= 1 { 1 } else { n * fact(n - 1) } + } + fact(5) + }|}) + +let test_higher_order () = + Alcotest.(check int) "higher order" 10 + (eval_to_int {|{ + fn apply(f: fn(Int) -> Int, x: Int) -> Int { f(x) } + fn double(x: Int) -> Int { x * 2 } + apply(double, 5) + }|}) + +let test_lambda () = + Alcotest.(check int) "lambda" 15 + (eval_to_int {|{ + let f = \x: Int -> x * 3; + f(5) + }|}) + +(* ========== Tuple Tests ========== *) + +let test_tuple () = + Alcotest.(check int) "tuple.0" 1 (eval_to_int "(1, 2, 3).0"); + Alcotest.(check int) "tuple.1" 2 (eval_to_int "(1, 2, 3).1") + +(* ========== Array Tests ========== *) + +let test_array_index () = + Alcotest.(check int) "array[1]" 20 (eval_to_int "[10, 20, 30][1]") + +let test_array_len () = + Alcotest.(check int) "len([1,2,3])" 3 (eval_to_int "len([1, 2, 3])") + +(* ========== Record Tests ========== *) + +let test_record () = + Alcotest.(check int) "record.x" 10 (eval_to_int "{x: 10, y: 20}.x") + +let test_record_shorthand () = + Alcotest.(check int) "shorthand" 5 (eval_to_int "{ let x = 5; {x}.x }") + +(* ========== Match Tests ========== *) + +let test_match_int () = + Alcotest.(check int) "match int" 10 + (eval_to_int "match 1 { 1 => 10, 2 => 20, _ => 0 }") + +let test_match_tuple () = + Alcotest.(check int) "match tuple" 3 + (eval_to_int "match (1, 2) { (a, b) => a + b }") + +(* ========== Loop Tests ========== *) + +let test_while_loop () = + Alcotest.(check int) "while" 10 + (eval_to_int {|{ + let mut x = 0; + while x < 10 { x += 1; } + x + }|}) + +let test_for_loop () = + Alcotest.(check int) "for" 6 + (eval_to_int {|{ + let mut sum = 0; + for i in [1, 2, 3] { sum += i; } + sum + }|}) + +(* ========== Builtin Tests ========== *) + +let test_builtin_range () = + Alcotest.(check int) "range len" 5 (eval_to_int "len(range(5))") + +let test_builtin_map () = + Alcotest.(check int) "map" 4 + (eval_to_int "map(\\x: Int -> x * 2, [1, 2])[1]") + +let test_builtin_filter 
() = + Alcotest.(check int) "filter" 2 + (eval_to_int "len(filter(\\x: Int -> x > 1, [1, 2, 3]))") + +let test_builtin_fold () = + Alcotest.(check int) "fold" 6 + (eval_to_int "fold(\\acc: Int, x: Int -> acc + x, 0, [1, 2, 3])") + +(* ========== Test Suite ========== *) + +let tests = [ + (* Literals *) + ("int literal", `Quick, test_int_literal); + ("negative int", `Quick, test_negative_int); + ("bool literal", `Quick, test_bool_literal); + ("string literal", `Quick, test_string_literal); + + (* Arithmetic *) + ("addition", `Quick, test_addition); + ("subtraction", `Quick, test_subtraction); + ("multiplication", `Quick, test_multiplication); + ("division", `Quick, test_division); + ("modulo", `Quick, test_modulo); + ("precedence", `Quick, test_precedence); + ("complex expr", `Quick, test_complex_expr); + + (* Comparison *) + ("equality", `Quick, test_equality); + ("inequality", `Quick, test_inequality); + ("less than", `Quick, test_less_than); + ("greater than", `Quick, test_greater_than); + + (* Logical *) + ("and", `Quick, test_and); + ("or", `Quick, test_or); + ("not", `Quick, test_not); + + (* Let bindings *) + ("let binding", `Quick, test_let_binding); + ("let shadowing", `Quick, test_let_shadowing); + ("let in expr", `Quick, test_let_in_expr); + + (* If expressions *) + ("if true", `Quick, test_if_true); + ("if false", `Quick, test_if_false); + ("nested if", `Quick, test_nested_if); + + (* Functions *) + ("function def", `Quick, test_function_def); + ("recursive fn", `Quick, test_recursive_fn); + ("higher order", `Quick, test_higher_order); + ("lambda", `Quick, test_lambda); + + (* Data structures *) + ("tuple", `Quick, test_tuple); + ("array index", `Quick, test_array_index); + ("array len", `Quick, test_array_len); + ("record", `Quick, test_record); + ("record shorthand", `Quick, test_record_shorthand); + + (* Match *) + ("match int", `Quick, test_match_int); + ("match tuple", `Quick, test_match_tuple); + + (* Loops *) + ("while loop", `Quick, 
test_while_loop); + ("for loop", `Quick, test_for_loop); + + (* Builtins *) + ("builtin range", `Quick, test_builtin_range); + ("builtin map", `Quick, test_builtin_map); + ("builtin filter", `Quick, test_builtin_filter); + ("builtin fold", `Quick, test_builtin_fold); +]