From 6e74ba5b1ad894adf1cadf14eccd005b0678eea4 Mon Sep 17 00:00:00 2001 From: Spotandjake Date: Sat, 28 Mar 2026 18:55:50 -0400 Subject: [PATCH] feat(runtime): Move `Memory.compare` from compiler to runtime This pr removes our `@wasm.memory_compare` primitive in favour of implementing it in the runtime. The reason I am making this change, is as we move away from linear memory and towards wasm gc it isn't as important that this function is in the compiler. Removing it from compcore lowers our maintance burden and moving it to grain makes it easier to manage. The bigger motivator is that our primitives tend to represent either low level things that aren't practical todo in the language itself or things with a direct wasm instruction and this fits neither of those, making the runtime a better home. Similar to how we handle all of the gc array comparisons once that pr lands. I'm not sure if we actually want to go forward with this change or not so i'm really just putting this pr in the case we do. Closes: #2375 --- compiler/src/codegen/compcore.re | 96 --------------------- compiler/src/codegen/mashtree.re | 1 - compiler/src/middle_end/analyze_purity.re | 1 - compiler/src/middle_end/anftree.re | 1 - compiler/src/middle_end/anftree.rei | 1 - compiler/src/parsing/parsetree.re | 1 - compiler/src/typed/translprim.re | 2 - compiler/src/typed/typecore.re | 9 -- compiler/src/typed/typedtree.re | 1 - compiler/src/typed/typedtree.rei | 1 - compiler/test/runtime/unsafe/memory.test.gr | 84 ++++++++++++++++++ compiler/test/suites/runtime.re | 1 + stdlib/runtime/unsafe/memory.gr | 10 ++- 13 files changed, 94 insertions(+), 115 deletions(-) create mode 100644 compiler/test/runtime/unsafe/memory.test.gr diff --git a/compiler/src/codegen/compcore.re b/compiler/src/codegen/compcore.re index d2cef90bca..3551c4e40b 100644 --- a/compiler/src/codegen/compcore.re +++ b/compiler/src/codegen/compcore.re @@ -2066,102 +2066,6 @@ let compile_primn = (wasm_mod, env: codegen_env, p, args): Expression.t => { const_void(wasm_mod), ], ) - | WasmMemoryCompare => - let lbl = gensym_label("memory_compare"); - let loop_lbl = gensym_label("memory_compare_loop"); - let set_ptr1 = set_swap(~ty=WasmValue(WasmI32), wasm_mod, env, 0); - let set_ptr2 = set_swap(~ty=WasmValue(WasmI32), wasm_mod, env, 1); - let set_count = set_swap(~ty=WasmValue(WasmI32), wasm_mod, env, 2); - let get_ptr1 = () => get_swap(~ty=WasmValue(WasmI32), wasm_mod, env, 0); - let get_ptr2 = () => get_swap(~ty=WasmValue(WasmI32), wasm_mod, env, 1); - let get_count = () => get_swap(~ty=WasmValue(WasmI32), wasm_mod, env, 2); - Expression.Block.make( - wasm_mod, - lbl, - [ - set_ptr1(compile_imm(wasm_mod, env, List.nth(args, 0))), - set_ptr2(compile_imm(wasm_mod, env, List.nth(args, 1))), - set_count(compile_imm(wasm_mod, env, List.nth(args, 2))), - Expression.Loop.make( - wasm_mod, - loop_lbl, - Expression.Block.make( - wasm_mod, - gensym_label("memory_compare_loop_inner"), - [ - Expression.Drop.make(wasm_mod) @@ - Expression.Break.make( - wasm_mod, - lbl, - Expression.Unary.make(wasm_mod, Op.eq_z_int32, get_count()), - Expression.Const.make(wasm_mod, const_int32(0)), - ), - Expression.If.make( - wasm_mod, - Expression.Binary.make( - wasm_mod, - Op.ne_int32, - load(~sz=1, ~signed=false, wasm_mod, get_ptr1()), - load(~sz=1, ~signed=false, wasm_mod, get_ptr2()), - ), - Expression.Break.make( - wasm_mod, - lbl, - Expression.Null.make(), - Expression.Select.make( - wasm_mod, - Expression.Binary.make( - wasm_mod, - Op.lt_u_int32, - load(~sz=1, ~signed=false, wasm_mod, get_ptr1()), - load(~sz=1, ~signed=false, wasm_mod, get_ptr2()), - ), - Expression.Const.make(wasm_mod, const_int32(-1)), - Expression.Const.make(wasm_mod, const_int32(1)), - ), - ), - Expression.Block.make( - wasm_mod, - gensym_label("memory_compare_loop_incr"), - [ - set_ptr1( - Expression.Binary.make( - wasm_mod, - Op.add_int32, - get_ptr1(), - Expression.Const.make(wasm_mod, const_int32(1)), - ), - ), - set_ptr2( - Expression.Binary.make( - wasm_mod, - Op.add_int32, - get_ptr2(), - Expression.Const.make(wasm_mod, const_int32(1)), - ), - ), - set_count( - Expression.Binary.make( - wasm_mod, - Op.sub_int32, - get_count(), - Expression.Const.make(wasm_mod, const_int32(1)), - ), - ), - Expression.Break.make( - wasm_mod, - loop_lbl, - Expression.Null.make(), - Expression.Null.make(), - ), - ], - ), - ), - ], - ), - ), - ], - ); | WasmRefArraySet({array_type}) => let array_type = get_array_type(~env, array_type); Expression.Block.make( diff --git a/compiler/src/codegen/mashtree.re b/compiler/src/codegen/mashtree.re index 671ea066ad..30cc710cd8 100644 --- a/compiler/src/codegen/mashtree.re +++ b/compiler/src/codegen/mashtree.re @@ -316,7 +316,6 @@ type primn = | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet({array_type: wasm_array_type}) | WasmRefArrayCopy({array_type: wasm_array_type}) | WasmRefArrayFill({array_type: wasm_array_type}) diff --git a/compiler/src/middle_end/analyze_purity.re b/compiler/src/middle_end/analyze_purity.re index 128e149214..221d6f7cec 100644 --- a/compiler/src/middle_end/analyze_purity.re +++ b/compiler/src/middle_end/analyze_purity.re @@ -132,7 +132,6 @@ module PurityArg: Anf_iterator.IterArgument = { WasmStoreI32(_) | WasmStoreI64(_) | WasmStoreF32 | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill | - WasmMemoryCompare | WasmRefArraySet(_) | WasmRefArrayCopy(_) | WasmRefArrayFill(_), diff --git a/compiler/src/middle_end/anftree.re b/compiler/src/middle_end/anftree.re index 9516479805..5c06dfbced 100644 --- a/compiler/src/middle_end/anftree.re +++ b/compiler/src/middle_end/anftree.re @@ -305,7 +305,6 @@ type primn = | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet({array_type: wasm_array_type}) | WasmRefArrayCopy({array_type: wasm_array_type}) | WasmRefArrayFill({array_type: wasm_array_type}) diff --git a/compiler/src/middle_end/anftree.rei b/compiler/src/middle_end/anftree.rei index b7a073a99e..ac0431a02f 100644 --- a/compiler/src/middle_end/anftree.rei +++ b/compiler/src/middle_end/anftree.rei @@ -306,7 +306,6 @@ type primn = | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet({array_type: wasm_array_type}) | WasmRefArrayCopy({array_type: wasm_array_type}) | WasmRefArrayFill({array_type: wasm_array_type}) diff --git a/compiler/src/parsing/parsetree.re b/compiler/src/parsing/parsetree.re index 3fd18e5291..30170797f3 100644 --- a/compiler/src/parsing/parsetree.re +++ b/compiler/src/parsing/parsetree.re @@ -497,7 +497,6 @@ type primn = | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet({array_type: wasm_array_type}) | WasmRefArrayCopy({array_type: wasm_array_type}) | WasmRefArrayFill({array_type: wasm_array_type}) diff --git a/compiler/src/typed/translprim.re b/compiler/src/typed/translprim.re index 398d769240..0306631985 100644 --- a/compiler/src/typed/translprim.re +++ b/compiler/src/typed/translprim.re @@ -1546,7 +1546,6 @@ let prim_map = ("@wasm.memory_size", Primitive0(WasmMemorySize)), ("@wasm.memory_copy", PrimitiveN(WasmMemoryCopy)), ("@wasm.memory_fill", PrimitiveN(WasmMemoryFill)), - ("@wasm.memory_compare", PrimitiveN(WasmMemoryCompare)), ("@wasm.ref_array_len", Primitive1(WasmRefArrayLen)), ( "@wasm.ref_array_i8_get_s", @@ -1699,7 +1698,6 @@ let transl_prim = (env, desc) => { | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet(_) => ( [lambda_arg(pat_a), lambda_arg(pat_b), lambda_arg(pat_c)], [id_a, id_b, id_c], diff --git a/compiler/src/typed/typecore.re b/compiler/src/typed/typecore.re index b01ff56e1e..ac97533bc3 100644 --- a/compiler/src/typed/typecore.re +++ b/compiler/src/typed/typecore.re @@ -504,15 +504,6 @@ let primn_type = ], Builtin_types.type_void, ) - | WasmMemoryCompare => - prim_type( - [ - ("ptr1", Builtin_types.type_wasmi32), - ("ptr2", Builtin_types.type_wasmi32), - ("length", Builtin_types.type_wasmi32), - ], - Builtin_types.type_wasmi32, - ) | WasmRefArraySet({array_type: Wasm_packed_i8}) => prim_type( [ diff --git a/compiler/src/typed/typedtree.re b/compiler/src/typed/typedtree.re index 1c3913f9ac..82b08e8969 100644 --- a/compiler/src/typed/typedtree.re +++ b/compiler/src/typed/typedtree.re @@ -328,7 +328,6 @@ type primn = | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet({array_type: wasm_array_type}) | WasmRefArrayCopy({array_type: wasm_array_type}) | WasmRefArrayFill({array_type: wasm_array_type}) diff --git a/compiler/src/typed/typedtree.rei b/compiler/src/typed/typedtree.rei index fb8ad96d43..194ddb2562 100644 --- a/compiler/src/typed/typedtree.rei +++ b/compiler/src/typed/typedtree.rei @@ -327,7 +327,6 @@ type primn = | WasmStoreF64 | WasmMemoryCopy | WasmMemoryFill - | WasmMemoryCompare | WasmRefArraySet({array_type: wasm_array_type}) | WasmRefArrayCopy({array_type: wasm_array_type}) | WasmRefArrayFill({array_type: wasm_array_type}) diff --git a/compiler/test/runtime/unsafe/memory.test.gr b/compiler/test/runtime/unsafe/memory.test.gr new file mode 100644 index 0000000000..90643605c7 --- /dev/null +++ b/compiler/test/runtime/unsafe/memory.test.gr @@ -0,0 +1,84 @@ +module MemoryTest + +from "runtime/unsafe/wasmi32" include WasmI32 +from "runtime/unsafe/memory" include Memory +from "runtime/malloc" include Malloc + +from "runtime/debugPrint" include DebugPrint + +// Memory.copy +@unsafe +let test = () => { + use WasmI32.{ (+), (-), (==), ltU as (<), gtU as (>) } + let length = 10n + let section1 = Malloc.malloc(length) + let section2 = Malloc.malloc(length) + // Clear both sections byte by byte + for (let mut i = 0n; i < length; i += 1n) { + WasmI32.store8(section1, 0n, i) + WasmI32.store8(section2, 0n, i) + } + // Set section1 to 1,2,...length + for (let mut i = 0n; i < length; i += 1n) { + WasmI32.store8(section1, i, i) + } + // Copy section1 to section2 + Memory.copy(section2, section1, length) + // Verify the copy was successful + for (let mut i = 0n; i < length; i += 1n) { + assert WasmI32.load8U(section1, i) == WasmI32.load8U(section2, i) + } + // Verify that overlapping regions are handled correctly by copying section1 to itself with an offset + let shift = 2n + Memory.copy(section1, section1 + shift, length - shift) + for (let mut i = 0n; i < length - shift; i += 1n) { + assert WasmI32.load8U(section1, i) == i + shift + } +} +test() + +// Memory.fill +@unsafe +let test = () => { + use WasmI32.{ (+), (==), ltU as (<) } + let length = 10n + let section = Malloc.malloc(10n) + // Clear the section byte by byte + for (let mut i = 0n; i < length; i += 1n) { + WasmI32.store8(section, 0n, i) + } + // Fill the section with `255` + Memory.fill(section, 255n, length) + // Verify the fill was successful + for (let mut i = 0n; i < length; i += 1n) { + assert WasmI32.load8U(section, i) == 255n + } +} +test() + +// Memory.compare +@unsafe +let test = () => { + use WasmI32.{ (-), (==), (<), (>) } + let length = 10n + let section1 = Malloc.malloc(length) + let section2 = Malloc.malloc(length) + // Equal regions + Memory.fill(section1, 0n, length) + Memory.fill(section2, 0n, length) + assert Memory.compare(section1, section2, 0n) == 0n + // First region less than second + Memory.fill(section1, 0n, length) + Memory.fill(section2, 1n, length) + assert Memory.compare(section1, section2, length) < 0n + // First region greater than second + Memory.fill(section1, 1n, length) + Memory.fill(section2, 0n, length) + assert Memory.compare(section1, section2, length) > 0n + // Regions differ at the last byte + Memory.fill(section1, 0n, length) + WasmI32.store8(section1, 255n, length - 1n) + Memory.fill(section2, 0n, length) + assert Memory.compare(section1, section2, length) > 0n +} +test() diff --git a/compiler/test/suites/runtime.re b/compiler/test/suites/runtime.re index 35cc3ec8ce..0d6baafc7a 100644 --- a/compiler/test/suites/runtime.re +++ b/compiler/test/suites/runtime.re @@ -13,4 +13,5 @@ describe("runtime", ({test, testSkip}) => { assertRuntime("unsafe/wasmi32.test"); assertRuntime("unsafe/wasmi64.test"); assertRuntime("unsafe/wasmref.test"); + assertRuntime("unsafe/memory.test"); }); diff --git a/stdlib/runtime/unsafe/memory.gr b/stdlib/runtime/unsafe/memory.gr index 5c02de3bd8..c617f9a869 100644 --- a/stdlib/runtime/unsafe/memory.gr +++ b/stdlib/runtime/unsafe/memory.gr @@ -33,7 +33,15 @@ provide primitive fill = "@wasm.memory_fill" * * @returns `0` if the memory regions are equal, a negative value if the first region is less than the second, and a positive value if the first region is greater than the second. */ -provide primitive compare = "@wasm.memory_compare" +provide let compare = (ptr1: WasmI32, ptr2: WasmI32, length: WasmI32) => { + use WasmI32.{ (!=), (-) } + for (let mut i = 0n; i < length; i += 1n) { + let byte1 = WasmI32.load8U(ptr1, i) + let byte2 = WasmI32.load8U(ptr2, i) + if (byte1 != byte2) return byte1 - byte2 + } + return 0n +} /** * Copies data from a (array i8) to linear memory.