Przeglądaj źródła

Implement Go To Declaration (#1540)

- gives a list of contract methods that the given contract method
overrides.
-  only returns the methods belonging to the immediate parent contracts.
-  handles multiple inheritance.

Signed-off-by: Govardhan G D <chioni1620@gmail.com>
Govardhan G D 2 lat temu
rodzic
commit
b972858ea7

+ 227 - 104
src/bin/languageserver/mod.rs

@@ -4,26 +4,33 @@ use itertools::Itertools;
 use num_traits::ToPrimitive;
 use rust_lapper::{Interval, Lapper};
 use serde_json::Value;
-use solang::sema::ast::{RetrieveType, StructType, Type};
 use solang::{
-    codegen::codegen,
-    codegen::{self, Expression},
+    codegen::{self, codegen, Expression},
     file_resolver::FileResolver,
     parse_and_resolve,
-    sema::{ast, builtin::get_prototype, symtable, tags::render},
+    sema::{
+        ast::{self, RetrieveType, StructType, Type},
+        builtin::get_prototype,
+        symtable,
+        tags::render,
+    },
     Target,
 };
 use solang_parser::pt;
-use std::{collections::HashMap, ffi::OsString, path::PathBuf};
+use std::{
+    collections::{HashMap, HashSet},
+    ffi::OsString,
+    path::PathBuf,
+};
 use tokio::sync::Mutex;
 use tower_lsp::{
     jsonrpc::{Error, ErrorCode, Result},
     lsp_types::{
         request::{
-            GotoImplementationParams, GotoImplementationResponse, GotoTypeDefinitionParams,
-            GotoTypeDefinitionResponse,
+            GotoDeclarationParams, GotoDeclarationResponse, GotoImplementationParams,
+            GotoImplementationResponse, GotoTypeDefinitionParams, GotoTypeDefinitionResponse,
         },
-        CompletionOptions, CompletionParams, CompletionResponse, Diagnostic,
+        CompletionOptions, CompletionParams, CompletionResponse, DeclarationCapability, Diagnostic,
         DiagnosticRelatedInformation, DiagnosticSeverity, DidChangeConfigurationParams,
         DidChangeTextDocumentParams, DidChangeWatchedFilesParams, DidChangeWorkspaceFoldersParams,
         DidCloseTextDocumentParams, DidOpenTextDocumentParams, DidSaveTextDocumentParams,
@@ -103,6 +110,8 @@ type ReferenceEntry = Interval<usize, DefinitionIndex>;
 type Implementations = HashMap<DefinitionIndex, Vec<DefinitionIndex>>;
 /// Stores types of code objects
 type Types = HashMap<DefinitionIndex, DefinitionIndex>;
+/// Stores all the functions that a given function overrides
+type Declarations = HashMap<DefinitionIndex, Vec<DefinitionIndex>>;
 
 /// Stores information used by language server for every opened file
 struct Files {
@@ -127,13 +136,24 @@ struct FileCache {
 /// Each field stores *some information* about a code object. The code object is uniquely identified by its `DefinitionIndex`.
 /// * `definitions` maps `DefinitionIndex` of a code object to its source code location where it is defined.
 /// * `types` maps the `DefinitionIndex` of a code object to that of its type.
+/// * `declarations` maps the `DefinitionIndex` of a `Contract` method to a list of methods that it overrides. The overridden methods belong to the parent `Contract`s
 /// * `implementations` maps the `DefinitionIndex` of a `Contract` to the `DefinitionIndex`s of methods defined as part of the `Contract`.
 struct GlobalCache {
     definitions: Definitions,
     types: Types,
+    declarations: Declarations,
     implementations: Implementations,
 }
 
+impl GlobalCache {
+    fn extend(&mut self, other: Self) {
+        self.definitions.extend(other.definitions);
+        self.types.extend(other.types);
+        self.declarations.extend(other.declarations);
+        self.implementations.extend(other.implementations);
+    }
+}
+
 // The language server currently stores some of the data grouped by the file to which the data belongs (Files struct).
 // Other data (Definitions) is not grouped by file due to problems faced during cleanup,
 // but is stored as a "global" field which is common to all files.
@@ -197,6 +217,7 @@ pub async fn start_server(language_args: &LanguageServerCommand) -> ! {
         global_cache: Mutex::new(GlobalCache {
             definitions: HashMap::new(),
             types: HashMap::new(),
+            declarations: HashMap::new(),
             implementations: HashMap::new(),
         }),
     });
@@ -281,19 +302,17 @@ impl SolangServer {
 
             let res = self.client.publish_diagnostics(uri, diags, None);
 
-            let (caches, definitions, types, implementations) = Builder::build(&ns);
+            let (file_caches, global_cache) = Builder::new(&ns).build();
 
             let mut files = self.files.lock().await;
-            for (f, c) in ns.files.iter().zip(caches.into_iter()) {
+            for (f, c) in ns.files.iter().zip(file_caches.into_iter()) {
                 if f.cache_no.is_some() {
                     files.caches.insert(f.path.clone(), c);
                 }
             }
 
             let mut gc = self.global_cache.lock().await;
-            gc.definitions.extend(definitions);
-            gc.types.extend(types);
-            gc.implementations.extend(implementations);
+            gc.extend(global_cache);
 
             res.await;
         }
@@ -337,13 +356,28 @@ struct Builder<'a> {
     references: Vec<(usize, ReferenceEntry)>,
 
     definitions: Definitions,
-    implementations: Implementations,
     types: Types,
+    declarations: Declarations,
+    implementations: Implementations,
 
     ns: &'a ast::Namespace,
 }
 
 impl<'a> Builder<'a> {
+    fn new(ns: &'a ast::Namespace) -> Self {
+        Self {
+            hovers: Vec::new(),
+            references: Vec::new(),
+
+            definitions: HashMap::new(),
+            types: HashMap::new(),
+            declarations: HashMap::new(),
+            implementations: HashMap::new(),
+
+            ns,
+        }
+    }
+
     // Constructs lookup table for the given statement by traversing the
     // statements and traversing inside the contents of the statements.
     fn statement(&mut self, stmt: &ast::Statement, symtab: &symtable::Symtable) {
@@ -1318,23 +1352,12 @@ impl<'a> Builder<'a> {
 
     /// Traverses namespace to extract information used later by the language server
     /// This includes hover messages, locations where code objects are declared and used
-    fn build(ns: &ast::Namespace) -> (Vec<FileCache>, Definitions, Types, Implementations) {
-        let mut builder = Builder {
-            hovers: Vec::new(),
-            references: Vec::new(),
-
-            definitions: HashMap::new(),
-            implementations: HashMap::new(),
-            types: HashMap::new(),
-
-            ns,
-        };
-
-        for (ei, enum_decl) in builder.ns.enums.iter().enumerate() {
+    fn build(mut self) -> (Vec<FileCache>, GlobalCache) {
+        for (ei, enum_decl) in self.ns.enums.iter().enumerate() {
             for (discriminant, (nam, loc)) in enum_decl.values.iter().enumerate() {
                 let file_no = loc.file_no();
-                let file = &ns.files[file_no];
-                builder.hovers.push((
+                let file = &self.ns.files[file_no];
+                self.hovers.push((
                     file_no,
                     HoverEntry {
                         start: loc.start(),
@@ -1350,17 +1373,15 @@ impl<'a> Builder<'a> {
                     def_path: file.path.clone(),
                     def_type: DefinitionType::Variant(ei, discriminant),
                 };
-                builder
-                    .definitions
-                    .insert(di.clone(), loc_to_range(loc, file));
+                self.definitions.insert(di.clone(), loc_to_range(loc, file));
 
                 let dt = DefinitionType::Enum(ei);
-                builder.types.insert(di, dt.into());
+                self.types.insert(di, dt.into());
             }
 
             let file_no = enum_decl.loc.file_no();
-            let file = &ns.files[file_no];
-            builder.hovers.push((
+            let file = &self.ns.files[file_no];
+            self.hovers.push((
                 file_no,
                 HoverEntry {
                     start: enum_decl.loc.start(),
@@ -1368,7 +1389,7 @@ impl<'a> Builder<'a> {
                     val: render(&enum_decl.tags[..]),
                 },
             ));
-            builder.definitions.insert(
+            self.definitions.insert(
                 DefinitionIndex {
                     def_path: file.path.clone(),
                     def_type: DefinitionType::Enum(ei),
@@ -1377,15 +1398,15 @@ impl<'a> Builder<'a> {
             );
         }
 
-        for (si, struct_decl) in builder.ns.structs.iter().enumerate() {
+        for (si, struct_decl) in self.ns.structs.iter().enumerate() {
             if let pt::Loc::File(_, start, _) = &struct_decl.loc {
                 for (fi, field) in struct_decl.fields.iter().enumerate() {
-                    builder.field(si, fi, field);
+                    self.field(si, fi, field);
                 }
 
                 let file_no = struct_decl.loc.file_no();
-                let file = &ns.files[file_no];
-                builder.hovers.push((
+                let file = &self.ns.files[file_no];
+                self.hovers.push((
                     file_no,
                     HoverEntry {
                         start: *start,
@@ -1393,7 +1414,7 @@ impl<'a> Builder<'a> {
                         val: render(&struct_decl.tags[..]),
                     },
                 ));
-                builder.definitions.insert(
+                self.definitions.insert(
                     DefinitionIndex {
                         def_path: file.path.clone(),
                         def_type: DefinitionType::Struct(si),
@@ -1403,26 +1424,26 @@ impl<'a> Builder<'a> {
             }
         }
 
-        for (i, func) in builder.ns.functions.iter().enumerate() {
+        for (i, func) in self.ns.functions.iter().enumerate() {
             if func.is_accessor || func.loc == pt::Loc::Builtin {
                 // accessor functions are synthetic; ignore them, all the locations are fake
                 continue;
             }
 
             if let Some(bump) = &func.annotations.bump {
-                builder.expression(&bump.1, &func.symtable);
+                self.expression(&bump.1, &func.symtable);
             }
 
             for seed in &func.annotations.seeds {
-                builder.expression(&seed.1, &func.symtable);
+                self.expression(&seed.1, &func.symtable);
             }
 
             if let Some(space) = &func.annotations.space {
-                builder.expression(&space.1, &func.symtable);
+                self.expression(&space.1, &func.symtable);
             }
 
             if let Some((loc, name)) = &func.annotations.payer {
-                builder.hovers.push((
+                self.hovers.push((
                     loc.file_no(),
                     HoverEntry {
                         start: loc.start(),
@@ -1433,33 +1454,32 @@ impl<'a> Builder<'a> {
             }
 
             for (i, param) in func.params.iter().enumerate() {
-                builder.hovers.push((
+                self.hovers.push((
                     param.loc.file_no(),
                     HoverEntry {
                         start: param.loc.start(),
                         stop: param.loc.exclusive_end(),
-                        val: builder.expanded_ty(&param.ty),
+                        val: self.expanded_ty(&param.ty),
                     },
                 ));
                 if let Some(Some(var_no)) = func.symtable.arguments.get(i) {
                     if let Some(id) = &param.id {
                         let file_no = id.loc.file_no();
-                        let file = &builder.ns.files[file_no];
+                        let file = &self.ns.files[file_no];
                         let di = DefinitionIndex {
                             def_path: file.path.clone(),
                             def_type: DefinitionType::Variable(*var_no),
                         };
-                        builder
-                            .definitions
+                        self.definitions
                             .insert(di.clone(), loc_to_range(&id.loc, file));
                         if let Some(dt) = get_type_definition(&param.ty) {
-                            builder.types.insert(di, dt.into());
+                            self.types.insert(di, dt.into());
                         }
                     }
                 }
                 if let Some(loc) = param.ty_loc {
                     if let Some(dt) = get_type_definition(&param.ty) {
-                        builder.references.push((
+                        self.references.push((
                             loc.file_no(),
                             ReferenceEntry {
                                 start: loc.start(),
@@ -1472,35 +1492,34 @@ impl<'a> Builder<'a> {
             }
 
             for (i, ret) in func.returns.iter().enumerate() {
-                builder.hovers.push((
+                self.hovers.push((
                     ret.loc.file_no(),
                     HoverEntry {
                         start: ret.loc.start(),
                         stop: ret.loc.exclusive_end(),
-                        val: builder.expanded_ty(&ret.ty),
+                        val: self.expanded_ty(&ret.ty),
                     },
                 ));
 
                 if let Some(id) = &ret.id {
                     if let Some(var_no) = func.symtable.returns.get(i) {
                         let file_no = id.loc.file_no();
-                        let file = &ns.files[file_no];
+                        let file = &self.ns.files[file_no];
                         let di = DefinitionIndex {
                             def_path: file.path.clone(),
                             def_type: DefinitionType::Variable(*var_no),
                         };
-                        builder
-                            .definitions
+                        self.definitions
                             .insert(di.clone(), loc_to_range(&id.loc, file));
                         if let Some(dt) = get_type_definition(&ret.ty) {
-                            builder.types.insert(di, dt.into());
+                            self.types.insert(di, dt.into());
                         }
                     }
                 }
 
                 if let Some(loc) = ret.ty_loc {
                     if let Some(dt) = get_type_definition(&ret.ty) {
-                        builder.references.push((
+                        self.references.push((
                             loc.file_no(),
                             ReferenceEntry {
                                 start: loc.start(),
@@ -1513,12 +1532,12 @@ impl<'a> Builder<'a> {
             }
 
             for stmt in &func.body {
-                builder.statement(stmt, &func.symtable);
+                self.statement(stmt, &func.symtable);
             }
 
             let file_no = func.loc.file_no();
-            let file = &ns.files[file_no];
-            builder.definitions.insert(
+            let file = &self.ns.files[file_no];
+            self.definitions.insert(
                 DefinitionIndex {
                     def_path: file.path.clone(),
                     def_type: DefinitionType::Function(i),
@@ -1527,26 +1546,26 @@ impl<'a> Builder<'a> {
             );
         }
 
-        for (i, constant) in builder.ns.constants.iter().enumerate() {
+        for (i, constant) in self.ns.constants.iter().enumerate() {
             let samptb = symtable::Symtable::new();
-            builder.contract_variable(constant, &samptb, None, i);
+            self.contract_variable(constant, &samptb, None, i);
         }
 
-        for (ci, contract) in builder.ns.contracts.iter().enumerate() {
+        for (ci, contract) in self.ns.contracts.iter().enumerate() {
             for base in &contract.bases {
                 let file_no = base.loc.file_no();
-                builder.hovers.push((
+                self.hovers.push((
                     file_no,
                     HoverEntry {
                         start: base.loc.start(),
                         stop: base.loc.exclusive_end(),
                         val: make_code_block(format!(
                             "contract {}",
-                            builder.ns.contracts[base.contract_no].name
+                            self.ns.contracts[base.contract_no].name
                         )),
                     },
                 ));
-                builder.references.push((
+                self.references.push((
                     file_no,
                     ReferenceEntry {
                         start: base.loc.start(),
@@ -1561,12 +1580,12 @@ impl<'a> Builder<'a> {
 
             for (i, variable) in contract.variables.iter().enumerate() {
                 let symtable = symtable::Symtable::new();
-                builder.contract_variable(variable, &symtable, Some(ci), i);
+                self.contract_variable(variable, &symtable, Some(ci), i);
             }
 
             let file_no = contract.loc.file_no();
-            let file = &ns.files[file_no];
-            builder.hovers.push((
+            let file = &self.ns.files[file_no];
+            self.hovers.push((
                 file_no,
                 HoverEntry {
                     start: contract.loc.start(),
@@ -1579,8 +1598,8 @@ impl<'a> Builder<'a> {
                 def_path: file.path.clone(),
                 def_type: DefinitionType::Contract(ci),
             };
-            builder
-                .definitions
+
+            self.definitions
                 .insert(cdi.clone(), loc_to_range(&contract.loc, file));
 
             let impls = contract
@@ -1591,17 +1610,68 @@ impl<'a> Builder<'a> {
                     def_type: DefinitionType::Function(*f),
                 })
                 .collect();
-            builder.implementations.insert(cdi, impls);
+
+            self.implementations.insert(cdi, impls);
+
+            let decls = contract
+                .virtual_functions
+                .iter()
+                .filter_map(|(_, indices)| {
+                    // due to the way the `indices` vector is populated during namespace creation,
+                    // the last element in the vector contains the overriding function that belongs to the current contract.
+                    let func = DefinitionIndex {
+                        def_path: file.path.clone(),
+                        // `unwrap` is alright here as the `indices` vector is guaranteed to have at least 1 element
+                        // the vector is always initialised with one initial element
+                        // and the elements in the vector are never removed during namespace construction
+                        def_type: DefinitionType::Function(indices.last().copied().unwrap()),
+                    };
+
+                    // get all the functions overridden by the current function
+                    let all_decls: HashSet<usize> = HashSet::from_iter(indices.iter().copied());
+
+                    // choose the overridden functions that belong to the parent contracts
+                    // due to multiple inheritance, a contract can have multiple parents
+                    let parent_decls = contract
+                        .bases
+                        .iter()
+                        .map(|b| {
+                            let p = &self.ns.contracts[b.contract_no];
+                            HashSet::from_iter(p.functions.iter().copied())
+                                .intersection(&all_decls)
+                                .copied()
+                                .collect::<HashSet<usize>>()
+                        })
+                        .reduce(|acc, e| acc.union(&e).copied().collect());
+
+                    // get the `DefinitionIndex`s of the overridden funcions
+                    parent_decls.map(|parent_decls| {
+                        let decls = parent_decls
+                            .iter()
+                            .map(|&i| {
+                                let loc = self.ns.functions[i].loc;
+                                DefinitionIndex {
+                                    def_path: self.ns.files[loc.file_no()].path.clone(),
+                                    def_type: DefinitionType::Function(i),
+                                }
+                            })
+                            .collect::<Vec<_>>();
+
+                        (func, decls)
+                    })
+                });
+
+            self.declarations.extend(decls);
         }
 
-        for (ei, event) in builder.ns.events.iter().enumerate() {
+        for (ei, event) in self.ns.events.iter().enumerate() {
             for (fi, field) in event.fields.iter().enumerate() {
-                builder.field(ei, fi, field);
+                self.field(ei, fi, field);
             }
 
             let file_no = event.loc.file_no();
-            let file = &ns.files[file_no];
-            builder.hovers.push((
+            let file = &self.ns.files[file_no];
+            self.hovers.push((
                 file_no,
                 HoverEntry {
                     start: event.loc.start(),
@@ -1610,7 +1680,7 @@ impl<'a> Builder<'a> {
                 },
             ));
 
-            builder.definitions.insert(
+            self.definitions.insert(
                 DefinitionIndex {
                     def_path: file.path.clone(),
                     def_type: DefinitionType::Event(ei),
@@ -1619,37 +1689,42 @@ impl<'a> Builder<'a> {
             );
         }
 
-        for lookup in &mut builder.hovers {
-            if let Some(msg) = builder.ns.hover_overrides.get(&pt::Loc::File(
-                lookup.0,
-                lookup.1.start,
-                lookup.1.stop,
-            )) {
+        for lookup in &mut self.hovers {
+            if let Some(msg) =
+                self.ns
+                    .hover_overrides
+                    .get(&pt::Loc::File(lookup.0, lookup.1.start, lookup.1.stop))
+            {
                 lookup.1.val = msg.clone();
             }
         }
 
-        let defs_to_files = builder
+        // `defs_to_files` and `defs_to_file_nos` are used to insert the correct filepath where a code object is defined.
+        // previously, a dummy path was filled.
+        // In a single namespace, there can't be two (or more) code objects with a given `DefinitionType`.
+        // So, there exists a one-to-one mapping between `DefinitionIndex` and `DefinitionType` when we are dealing with just one namespace.
+        let defs_to_files = self
             .definitions
             .keys()
             .map(|key| (key.def_type.clone(), key.def_path.clone()))
             .collect::<HashMap<DefinitionType, PathBuf>>();
 
-        let defs_to_file_nos = ns
+        let defs_to_file_nos = self
+            .ns
             .files
             .iter()
             .enumerate()
             .map(|(i, f)| (f.path.clone(), i))
             .collect::<HashMap<PathBuf, usize>>();
 
-        for val in builder.types.values_mut() {
+        for val in self.types.values_mut() {
             val.def_path = defs_to_files[&val.def_type].clone();
         }
 
-        for (di, range) in &builder.definitions {
+        for (di, range) in &self.definitions {
             let file_no = defs_to_file_nos[&di.def_path];
-            let file = &ns.files[file_no];
-            builder.references.push((
+            let file = &self.ns.files[file_no];
+            self.references.push((
                 file_no,
                 ReferenceEntry {
                     start: file
@@ -1663,23 +1738,24 @@ impl<'a> Builder<'a> {
             ));
         }
 
-        let caches = ns
+        let file_caches = self
+            .ns
             .files
             .iter()
             .enumerate()
             .map(|(i, f)| FileCache {
                 file: f.clone(),
+                // get `hovers` that belong to the current file
                 hovers: Lapper::new(
-                    builder
-                        .hovers
+                    self.hovers
                         .iter()
                         .filter(|h| h.0 == i)
                         .map(|(_, i)| i.clone())
                         .collect(),
                 ),
+                // get `references` that belong to the current file
                 references: Lapper::new(
-                    builder
-                        .references
+                    self.references
                         .iter()
                         .filter(|h| h.0 == i)
                         .map(|(_, i)| {
@@ -1692,12 +1768,14 @@ impl<'a> Builder<'a> {
             })
             .collect();
 
-        (
-            caches,
-            builder.definitions,
-            builder.types,
-            builder.implementations,
-        )
+        let global_cache = GlobalCache {
+            definitions: self.definitions,
+            types: self.types,
+            declarations: self.declarations,
+            implementations: self.implementations,
+        };
+
+        (file_caches, global_cache)
     }
 
     /// Render the type with struct/enum fields expanded
@@ -1781,6 +1859,7 @@ impl LanguageServer for SolangServer {
                 definition_provider: Some(OneOf::Left(true)),
                 type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)),
                 implementation_provider: Some(ImplementationProviderCapability::Simple(true)),
+                declaration_provider: Some(DeclarationCapability::Simple(true)),
                 references_provider: Some(OneOf::Left(true)),
                 rename_provider: Some(OneOf::Left(true)),
                 ..ServerCapabilities::default()
@@ -2079,6 +2158,50 @@ impl LanguageServer for SolangServer {
         Ok(impls)
     }
 
+    /// Called when "Go to Declaration" is called by the user on the client side.
+    ///
+    /// Expected to return a list (possibly empty) of methods that the given method overrides.
+    /// Only the methods belonging to the immediate parent contracts (due to multiple inheritance, there can be more than one parent) are to be returned.
+    ///
+    /// ### Arguments
+    /// * `GotoDeclarationParams` provides the source code location (filename, line number, column number) of the code object for which the request was made.
+    ///
+    /// ### Edge cases
+    /// * Returns `Err` when an invalid file path is received.
+    /// * Returns `Ok(None)` when the location passed in the arguments doesn't belong to a contract method defined in user code.
+    async fn goto_declaration(
+        &self,
+        params: GotoDeclarationParams,
+    ) -> Result<Option<GotoDeclarationResponse>> {
+        // fetch the `DefinitionIndex` of the code object in question
+        let Some(reference) = self.get_reference_from_params(params).await? else {
+            return Ok(None);
+        };
+
+        let gc = self.global_cache.lock().await;
+
+        // get a list of `DefinitionIndex`s of overridden functions from parent contracts
+        let decls = gc.declarations.get(&reference);
+
+        // get a list of locations in source code where the overridden functions are present
+        let locations = decls
+            .map(|decls| {
+                decls
+                    .iter()
+                    .filter_map(|di| {
+                        let path = &di.def_path;
+                        gc.definitions.get(di).map(|range| {
+                            let uri = Url::from_file_path(path).unwrap();
+                            Location { uri, range: *range }
+                        })
+                    })
+                    .collect()
+            })
+            .map(GotoImplementationResponse::Array);
+
+        Ok(locations)
+    }
+
     /// Called when "Go to References" is called by the user on the client side.
     ///
     /// Expected to return a list of locations in the source code where the given code-object is used.
@@ -2167,7 +2290,7 @@ impl LanguageServer for SolangServer {
         let new_text = params.new_name;
 
         // create `TextEdit` instances that represent the changes to be made for every occurrence of the old symbol
-        // these `TextEdit` objects are then grouped into separate list per source file to which they bolong
+        // these `TextEdit` objects are then grouped into separate list per source file to which they belong
         let caches = &self.files.lock().await.caches;
         let ws = caches
             .iter()

+ 6 - 0
src/codegen/cfg.rs

@@ -1443,6 +1443,7 @@ fn is_there_virtual_function(
     if func.ty == pt::FunctionTy::Receive {
         // if there is a virtual receive function, and it's not this one, ignore it
         if let Some(receive) = ns.contracts[contract_no].virtual_functions.get("@receive") {
+            let receive = receive.last().unwrap();
             if Some(*receive) != function_no {
                 return true;
             }
@@ -1452,6 +1453,7 @@ fn is_there_virtual_function(
     if func.ty == pt::FunctionTy::Fallback {
         // if there is a virtual fallback function, and it's not this one, ignore it
         if let Some(fallback) = ns.contracts[contract_no].virtual_functions.get("@fallback") {
+            let fallback = fallback.last().unwrap();
             if Some(*fallback) != function_no {
                 return true;
             }
@@ -1537,6 +1539,9 @@ fn resolve_modifier_call<'a>(
             // is it a virtual function call
             let function_no = if let Some(signature) = signature {
                 contract.virtual_functions[signature]
+                    .last()
+                    .copied()
+                    .unwrap()
             } else {
                 *function_no
             };
@@ -2130,6 +2135,7 @@ impl Namespace {
             && self.contracts[contract_no]
                 .virtual_functions
                 .get(&func.signature)
+                .and_then(|v| v.last())
                 != function_no.as_ref()
         {
             return false;

+ 6 - 1
src/codegen/expression.rs

@@ -490,7 +490,9 @@ pub fn expression(
             ..
         } => {
             let function_no = if let Some(signature) = signature {
-                &ns.contracts[contract_no].virtual_functions[signature]
+                ns.contracts[contract_no].virtual_functions[signature]
+                    .last()
+                    .unwrap()
             } else {
                 function_no
             };
@@ -2696,6 +2698,9 @@ pub fn emit_function_call(
 
                 let function_no = if let Some(signature) = signature {
                     ns.contracts[caller_contract_no].virtual_functions[signature]
+                        .last()
+                        .copied()
+                        .unwrap()
                 } else {
                     *function_no
                 };

+ 4 - 1
src/sema/ast.rs

@@ -755,7 +755,10 @@ pub struct Contract {
     pub fixed_layout_size: BigInt,
     pub functions: Vec<usize>,
     pub all_functions: BTreeMap<usize, usize>,
-    pub virtual_functions: HashMap<String, usize>,
+    /// maps the name of virtual functions to a vector of overriden functions.
+    /// Each time a virtual function is overriden, there will be an entry pushed to the vector. The last
+    /// element represents the current overriding function - there will be at least one entry in this vector.
+    pub virtual_functions: HashMap<String, Vec<usize>>,
     pub yul_functions: Vec<usize>,
     pub variables: Vec<Variable>,
     /// List of contracts this contract instantiates

+ 3 - 1
src/sema/contracts.rs

@@ -598,7 +598,9 @@ fn check_inheritance(contract_no: usize, ns: &mut ast::Namespace) {
             if cur.is_override.is_some() || cur.is_virtual {
                 ns.contracts[contract_no]
                     .virtual_functions
-                    .insert(signature, function_no);
+                    .entry(signature)
+                    .or_insert_with(Vec::new)
+                    .push(function_no); // there is always at least 1 element in the vector
             }
 
             ns.contracts[contract_no]

+ 3 - 0
src/sema/mutability.rs

@@ -186,6 +186,9 @@ fn check_mutability(func: &Function, ns: &Namespace) -> Vec<Diagnostic> {
             {
                 let function_no = if let Some(signature) = signature {
                     state.ns.contracts[contract_no].virtual_functions[signature]
+                        .last()
+                        .copied()
+                        .unwrap()
                 } else {
                     *function_no
                 };

+ 99 - 39
vscode/src/test/suite/extension.test.ts

@@ -83,6 +83,13 @@ suite('Extension Test Suite', function () {
     await testtypedefs(typedefdoc1);
   });
 
+  // Tests for goto-declaration
+  this.timeout(20000);
+  const declsdoc1 = getDocUri('impls.sol');
+  test('Testing for GoToDeclaration', async () => {
+    await testdecls(declsdoc1);
+  });
+
   // Tests for goto-impls
   this.timeout(20000);
   const implsdoc1 = getDocUri('impls.sol');
@@ -119,8 +126,8 @@ async function testdefs(docUri: vscode.Uri) {
     'vscode.executeDefinitionProvider',
     docUri,
     pos1
-  )) as vscode.Definition[];
-  const loc1 = actualdef1[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc1 = actualdef1[0];
   assert.strictEqual(loc1.range.start.line, 27);
   assert.strictEqual(loc1.range.start.character, 24);
   assert.strictEqual(loc1.range.end.line, 27);
@@ -132,8 +139,8 @@ async function testdefs(docUri: vscode.Uri) {
     'vscode.executeDefinitionProvider',
     docUri,
     pos2
-  )) as vscode.Definition[];
-  const loc2 = actualdef2[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc2 = actualdef2[0];
   assert.strictEqual(loc2.range.start.line, 27);
   assert.strictEqual(loc2.range.start.character, 50);
   assert.strictEqual(loc2.range.end.line, 27);
@@ -145,8 +152,8 @@ async function testdefs(docUri: vscode.Uri) {
     'vscode.executeDefinitionProvider',
     docUri,
     pos3
-  )) as vscode.Definition[];
-  const loc3 = actualdef3[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc3 = actualdef3[0];
   assert.strictEqual(loc3.range.start.line, 19);
   assert.strictEqual(loc3.range.start.character, 8);
   assert.strictEqual(loc3.range.end.line, 19);
@@ -158,8 +165,8 @@ async function testdefs(docUri: vscode.Uri) {
     'vscode.executeDefinitionProvider',
     docUri,
     pos4
-  )) as vscode.Definition[];
-  const loc4 = actualdef4[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc4 = actualdef4[0];
   assert.strictEqual(loc4.range.start.line, 23);
   assert.strictEqual(loc4.range.start.character, 8);
   assert.strictEqual(loc4.range.end.line, 23);
@@ -171,8 +178,8 @@ async function testdefs(docUri: vscode.Uri) {
     'vscode.executeDefinitionProvider',
     docUri,
     pos5
-  )) as vscode.Definition[];
-  const loc5 = actualdef5[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc5 = actualdef5[0];
   assert.strictEqual(loc5.range.start.line, 24);
   assert.strictEqual(loc5.range.start.character, 8);
   assert.strictEqual(loc5.range.end.line, 24);
@@ -188,8 +195,8 @@ async function testtypedefs(docUri: vscode.Uri) {
     'vscode.executeTypeDefinitionProvider',
     docUri,
     pos0,
-  )) as vscode.Definition[];
-  const loc0 = actualtypedef0[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc0 = actualtypedef0[0];
   assert.strictEqual(loc0.range.start.line, 22);
   assert.strictEqual(loc0.range.start.character, 11);
   assert.strictEqual(loc0.range.end.line, 22);
@@ -201,8 +208,8 @@ async function testtypedefs(docUri: vscode.Uri) {
     'vscode.executeTypeDefinitionProvider',
     docUri,
     pos1,
-  )) as vscode.Definition[];
-  const loc1 = actualtypedef1[0] as vscode.Location;
+  )) as vscode.Location[];
+  const loc1 = actualtypedef1[0];
   assert.strictEqual(loc1.range.start.line, 7);
   assert.strictEqual(loc1.range.start.character, 4);
   assert.strictEqual(loc1.range.end.line, 21);
@@ -210,6 +217,59 @@ async function testtypedefs(docUri: vscode.Uri) {
   assert.strictEqual(loc1.uri.path, docUri.path);
 }
 
+async function testdecls(docUri: vscode.Uri) {
+  await activate(docUri);
+
+  const pos0 = new vscode.Position(6, 14);
+  const actualdecl0 = (await vscode.commands.executeCommand(
+    'vscode.executeDeclarationProvider',
+    docUri,
+    pos0,
+  )) as vscode.Location[];
+  assert.strictEqual(actualdecl0.length, 2);
+  const loc00 = actualdecl0[0];
+  assert.strictEqual(loc00.range.start.line, 12);
+  assert.strictEqual(loc00.range.start.character, 4);
+  assert.strictEqual(loc00.range.end.line, 12);
+  assert.strictEqual(loc00.range.end.character, 61);
+  assert.strictEqual(loc00.uri.path, docUri.path);
+  const loc01 = actualdecl0[1];
+  assert.strictEqual(loc01.range.start.line, 22);
+  assert.strictEqual(loc01.range.start.character, 4);
+  assert.strictEqual(loc01.range.end.line, 22);
+  assert.strictEqual(loc01.range.end.character, 61);
+  assert.strictEqual(loc01.uri.path, docUri.path);
+
+  const pos1 = new vscode.Position(12, 14);
+  const actualdecl1 = (await vscode.commands.executeCommand(
+    'vscode.executeDeclarationProvider',
+    docUri,
+    pos1,
+  )) as vscode.Location[];
+  assert.strictEqual(actualdecl1.length, 1);
+  const loc10 = actualdecl1[0];
+  assert.strictEqual(loc10.range.start.line, 32);
+  assert.strictEqual(loc10.range.start.character, 4);
+  assert.strictEqual(loc10.range.end.line, 32);
+  assert.strictEqual(loc10.range.end.character, 52);
+  assert.strictEqual(loc10.uri.path, docUri.path);
+
+  const pos2 = new vscode.Position(22, 14);
+  const actualdecl2 = (await vscode.commands.executeCommand(
+    'vscode.executeDeclarationProvider',
+    docUri,
+    pos2,
+  )) as vscode.Location[];
+  assert.strictEqual(actualdecl2.length, 1);
+  const loc20 = actualdecl2[0];
+  assert.strictEqual(loc20.range.start.line, 32);
+  assert.strictEqual(loc20.range.start.character, 4);
+  assert.strictEqual(loc20.range.end.line, 32);
+  assert.strictEqual(loc20.range.end.character, 52);
+  assert.strictEqual(loc20.uri.path, docUri.path);
+}
+
+
 async function testimpls(docUri: vscode.Uri) {
   await activate(docUri);
 
@@ -218,19 +278,19 @@ async function testimpls(docUri: vscode.Uri) {
     'vscode.executeImplementationProvider',
     docUri,
     pos0,
-  )) as vscode.Definition[];
+  )) as vscode.Location[];
   assert.strictEqual(actualimpl0.length, 2);
-  const loc00 = actualimpl0[0] as vscode.Location;
+  const loc00 = actualimpl0[0];
   assert.strictEqual(loc00.range.start.line, 1);
   assert.strictEqual(loc00.range.start.character, 4);
   assert.strictEqual(loc00.range.end.line, 1);
   assert.strictEqual(loc00.range.end.character, 42);
   assert.strictEqual(loc00.uri.path, docUri.path);
-  const loc01 = actualimpl0[1] as vscode.Location;
+  const loc01 = actualimpl0[1];
   assert.strictEqual(loc01.range.start.line, 6);
   assert.strictEqual(loc01.range.start.character, 4);
   assert.strictEqual(loc01.range.end.line, 6);
-  assert.strictEqual(loc01.range.end.character, 61);
+  assert.strictEqual(loc01.range.end.character, 65);
   assert.strictEqual(loc01.uri.path, docUri.path);
 
 
@@ -239,15 +299,15 @@ async function testimpls(docUri: vscode.Uri) {
     'vscode.executeImplementationProvider',
     docUri,
     pos1,
-  )) as vscode.Definition[];
+  )) as vscode.Location[];
   assert.strictEqual(actualimpl1.length, 2);
-  const loc10 = actualimpl1[0] as vscode.Location;
+  const loc10 = actualimpl1[0];
   assert.strictEqual(loc10.range.start.line, 12);
   assert.strictEqual(loc10.range.start.character, 4);
   assert.strictEqual(loc10.range.end.line, 12);
-  assert.strictEqual(loc10.range.end.character, 52);
+  assert.strictEqual(loc10.range.end.character, 61);
   assert.strictEqual(loc10.uri.path, docUri.path);
-  const loc11 = actualimpl1[1] as vscode.Location;
+  const loc11 = actualimpl1[1];
   assert.strictEqual(loc11.range.start.line, 16);
   assert.strictEqual(loc11.range.start.character, 4);
   assert.strictEqual(loc11.range.end.line, 16);
@@ -260,15 +320,15 @@ async function testimpls(docUri: vscode.Uri) {
     'vscode.executeImplementationProvider',
     docUri,
     pos2,
-  )) as vscode.Definition[];
+  )) as vscode.Location[];
   assert.strictEqual(actualimpl2.length, 2);
-  const loc20 = actualimpl2[0] as vscode.Location;
+  const loc20 = actualimpl2[0];
   assert.strictEqual(loc20.range.start.line, 22);
   assert.strictEqual(loc20.range.start.character, 4);
   assert.strictEqual(loc20.range.end.line, 22);
-  assert.strictEqual(loc20.range.end.character, 52);
+  assert.strictEqual(loc20.range.end.character, 61);
   assert.strictEqual(loc20.uri.path, docUri.path);
-  const loc21 = actualimpl2[1] as vscode.Location;
+  const loc21 = actualimpl2[1];
   assert.strictEqual(loc21.range.start.line, 26);
   assert.strictEqual(loc21.range.start.character, 4);
   assert.strictEqual(loc21.range.end.line, 26);
@@ -284,33 +344,33 @@ async function testrefs(docUri: vscode.Uri) {
     'vscode.executeReferenceProvider',
     docUri,
     pos0,
-  )) as vscode.Definition[];
+  )) as vscode.Location[];
   assert.strictEqual(actualref0.length, 5);
-  const loc00 = actualref0[0] as vscode.Location;
+  const loc00 = actualref0[0];
   assert.strictEqual(loc00.range.start.line, 27);
   assert.strictEqual(loc00.range.start.character, 50);
   assert.strictEqual(loc00.range.end.line, 27);
   assert.strictEqual(loc00.range.end.character, 55);
   assert.strictEqual(loc00.uri.path, docUri.path);
-  const loc01 = actualref0[1] as vscode.Location;
+  const loc01 = actualref0[1];
   assert.strictEqual(loc01.range.start.line, 30);
   assert.strictEqual(loc01.range.start.character, 16);
   assert.strictEqual(loc01.range.end.line, 30);
   assert.strictEqual(loc01.range.end.character, 21);
   assert.strictEqual(loc01.uri.path, docUri.path);
-  const loc02 = actualref0[2] as vscode.Location;
+  const loc02 = actualref0[2];
   assert.strictEqual(loc02.range.start.line, 33);
   assert.strictEqual(loc02.range.start.character, 16);
   assert.strictEqual(loc02.range.end.line, 33);
   assert.strictEqual(loc02.range.end.character, 21);
   assert.strictEqual(loc02.uri.path, docUri.path);
-  const loc03 = actualref0[3] as vscode.Location;
+  const loc03 = actualref0[3];
   assert.strictEqual(loc03.range.start.line, 36);
   assert.strictEqual(loc03.range.start.character, 16);
   assert.strictEqual(loc03.range.end.line, 36);
   assert.strictEqual(loc03.range.end.character, 21);
   assert.strictEqual(loc03.uri.path, docUri.path);
-  const loc04 = actualref0[4] as vscode.Location;
+  const loc04 = actualref0[4];
   assert.strictEqual(loc04.range.start.line, 39);
   assert.strictEqual(loc04.range.start.character, 16);
   assert.strictEqual(loc04.range.end.line, 39);
@@ -322,39 +382,39 @@ async function testrefs(docUri: vscode.Uri) {
     'vscode.executeReferenceProvider',
     docUri,
     pos1,
-  )) as vscode.Definition[];
+  )) as vscode.Location[];
   assert.strictEqual(actualref1.length, 6);
-  const loc10 = actualref1[0] as vscode.Location;
+  const loc10 = actualref1[0];
   assert.strictEqual(loc10.range.start.line, 27);
   assert.strictEqual(loc10.range.start.character, 24);
   assert.strictEqual(loc10.range.end.line, 27);
   assert.strictEqual(loc10.range.end.character, 25);
   assert.strictEqual(loc10.uri.path, docUri.path);
-  const loc11 = actualref1[1] as vscode.Location;
+  const loc11 = actualref1[1];
   assert.strictEqual(loc11.range.start.line, 28);
   assert.strictEqual(loc11.range.start.character, 12);
   assert.strictEqual(loc11.range.end.line, 28);
   assert.strictEqual(loc11.range.end.character, 13);
   assert.strictEqual(loc11.uri.path, docUri.path);
-  const loc12 = actualref1[2] as vscode.Location;
+  const loc12 = actualref1[2];
   assert.strictEqual(loc12.range.start.line, 29);
   assert.strictEqual(loc12.range.start.character, 16);
   assert.strictEqual(loc12.range.end.line, 29);
   assert.strictEqual(loc12.range.end.character, 17);
   assert.strictEqual(loc12.uri.path, docUri.path);
-  const loc13 = actualref1[3] as vscode.Location;
+  const loc13 = actualref1[3];
   assert.strictEqual(loc13.range.start.line, 32);
   assert.strictEqual(loc13.range.start.character, 16);
   assert.strictEqual(loc13.range.end.line, 32);
   assert.strictEqual(loc13.range.end.character, 17);
   assert.strictEqual(loc13.uri.path, docUri.path);
-  const loc14 = actualref1[4] as vscode.Location;
+  const loc14 = actualref1[4];
   assert.strictEqual(loc14.range.start.line, 35);
   assert.strictEqual(loc14.range.start.character, 16);
   assert.strictEqual(loc14.range.end.line, 35);
   assert.strictEqual(loc14.range.end.character, 17);
   assert.strictEqual(loc14.uri.path, docUri.path);
-  const loc15 = actualref1[5] as vscode.Location;
+  const loc15 = actualref1[5];
   assert.strictEqual(loc15.range.start.line, 38);
   assert.strictEqual(loc15.range.start.character, 16);
   assert.strictEqual(loc15.range.end.line, 38);

+ 11 - 5
vscode/src/testFixture/impls.sol

@@ -4,13 +4,13 @@ contract a is b1, b2 {
         return super.foo();
     }
 
-    function foo() internal override(b1, b2) returns (uint64) {
+    function foo() internal override(b1, b2, b3) returns (uint64) {
         return 2;
     }
 }
 
-abstract contract b1 {
-    function foo() internal virtual returns (uint64) {
+abstract contract b1 is b3 {
+    function foo() internal virtual override returns (uint64) {
         return 100;
     }
 
@@ -19,8 +19,8 @@ abstract contract b1 {
     }
 }
 
-abstract contract b2 {
-    function foo() internal virtual returns (uint64) {
+abstract contract b2 is b3 {
+    function foo() internal virtual override returns (uint64) {
         return 200;
     }
     
@@ -28,3 +28,9 @@ abstract contract b2 {
         return 25;
     }
 }
+
+abstract contract b3 {
+    function foo() internal virtual returns (uint64) {
+        return 400;
+    }
+}