diff --git a/src/HashMap.mo b/src/HashMap.mo index 6e5ad9de..f76f99cf 100644 --- a/src/HashMap.mo +++ b/src/HashMap.mo @@ -17,12 +17,16 @@ import A "Array"; import Hash "Hash"; import Iter "Iter"; import AssocList "AssocList"; +import Nat32 "Nat32"; module { + // hash field avoids re-hashing the key when the array grows. + type Key = (Hash.Hash, K); + // key-val list type - type KVs = AssocList.AssocList; + type KVs = AssocList.AssocList, V>; /// An imperative HashMap with a minimal object-oriented interface. /// Maps keys of type `K` to values of type `V`. @@ -41,6 +45,10 @@ module { /// exist. public func delete(k : K) = ignore remove(k); + func keyHash_(k : K) : Key = (keyHash(k), k); + + func keyHashEq(k1 : Key, k2 : Key) : Bool { k1.0 == k2.0 and keyEq(k1.1, k2.1) }; + /// Removes the entry with the key `k` and returns the associated value if it /// existed or `null` otherwise. public func remove(k : K) : ?V { @@ -48,7 +56,7 @@ module { if (m > 0) { let h = Prim.nat32ToNat(keyHash(k)); let pos = h % m; - let (kvs2, ov) = AssocList.replace(table[pos], k, keyEq, null); + let (kvs2, ov) = AssocList.replace, V>(table[pos], keyHash_(k), keyHashEq, null); table[pos] := kvs2; switch(ov){ case null { }; @@ -66,7 +74,7 @@ module { let h = Prim.nat32ToNat(keyHash(k)); let m = table.size(); let v = if (m > 0) { - AssocList.find(table[h % m], k, keyEq) + AssocList.find, V>(table[h % m], keyHash_(k), keyHashEq) } else { null }; @@ -97,8 +105,7 @@ module { switch kvs { case null { break moveKeyVals }; case (?((k, v), kvsTail)) { - let h = Prim.nat32ToNat(keyHash(k)); - let pos2 = h % table2.size(); + let pos2 = Nat32.toNat(k.0) % table2.size(); // critical: uses saved hash. no re-hash. table2[pos2] := ?((k,v), table2[pos2]); kvs := kvsTail; }; @@ -109,7 +116,7 @@ module { }; let h = Prim.nat32ToNat(keyHash(k)); let pos = h % table.size(); - let (kvs2, ov) = AssocList.replace(table[pos], k, keyEq, ?v); + let (kvs2, ov) = AssocList.replace, V>(table[pos], keyHash_(k), keyHashEq, ?v); table[pos] := kvs2; switch(ov){ case null { _count += 1 }; @@ -140,7 +147,7 @@ module { switch kvs { case (?(kv, kvs2)) { kvs := kvs2; - ?kv + ?(kv.0.1, kv.1) }; case null { if (nextTablePos < table.size()) {