Skip to content

[feat] syntactic semantic tokens #4672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions ghcide/src/Development/IDE/GHC/Compat/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -630,11 +630,47 @@ instance HasSrcSpan SrcSpan where
instance HasSrcSpan (SrcLoc.GenLocated SrcSpan a) where
getLoc = GHC.getLoc

#if MIN_VERSION_ghc(9,11,0)
instance HasSrcSpan (GHC.EpToken sym) where
getLoc = GHC.getHasLoc
#elif MIN_VERSION_ghc(9,9,0)
instance HasSrcSpan (GHC.EpToken sym) where
getLoc = GHC.getHasLoc . \case
GHC.NoEpTok -> Nothing
GHC.EpTok loc -> Just loc
#endif

#if MIN_VERSION_ghc(9,9,0)
instance HasSrcSpan (EpAnn a) where
getLoc = GHC.getHasLoc
#endif

#if !MIN_VERSION_ghc(9,11,0)
instance HasSrcSpan GHC.AddEpAnn where
getLoc (GHC.AddEpAnn _ loc) = getLoc loc

instance HasSrcSpan GHC.EpaLocation where
#if MIN_VERSION_ghc(9,9,0)
getLoc loc = GHC.getHasLoc loc
#else
getLoc loc = case loc of
GHC.EpaSpan span bufspan -> RealSrcSpan span $ case bufspan of Strict.Nothing -> Nothing; Strict.Just a -> Just a
GHC.EpaDelta {} -> panic "compiler inserted epadelta in EpaLocation"
#endif
#endif

instance HasSrcSpan GHC.LEpaComment where
#if MIN_VERSION_ghc(9,9,0)
getLoc :: GHC.LEpaComment -> SrcSpan
getLoc (GHC.L l _) = case l of
SrcLoc.EpaDelta {} -> panic "compiler inserted epadelta into NoCommentsLocation"
SrcLoc.EpaSpan span -> span
#else
getLoc :: GHC.LEpaComment -> SrcSpan
getLoc c = case c of
SrcLoc.L (GHC.Anchor realSpan _) _ -> RealSrcSpan realSpan Nothing
#endif

#if MIN_VERSION_ghc(9,9,0)
instance HasSrcSpan (SrcLoc.GenLocated (EpAnn ann) a) where
getLoc (L l _) = getLoc l
Expand Down
1 change: 1 addition & 0 deletions haskell-language-server.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ library hls-semantic-tokens-plugin
, containers
, extra
, text-rope
, ghc
, mtl >= 2.2
, ghcide == 2.11.0.0
, hls-plugin-api == 2.11.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ descriptor recorder plId =
{ Ide.Types.pluginHandlers =
mkPluginHandler SMethod_TextDocumentSemanticTokensFull (Internal.semanticTokensFull recorder)
<> mkPluginHandler SMethod_TextDocumentSemanticTokensFullDelta (Internal.semanticTokensFullDelta recorder),
Ide.Types.pluginRules = Internal.getSemanticTokensRule recorder,
Ide.Types.pluginRules = Internal.getSemanticTokensRule recorder <> Internal.getSyntacticTokensRule recorder,
pluginConfigDescriptor =
defaultConfigDescriptor
{ configInitialGenericConfig = (configInitialGenericConfig defaultConfigDescriptor) {plcGlobalOn = False}
{ configInitialGenericConfig = (configInitialGenericConfig defaultConfigDescriptor) {plcGlobalOn = True}
, configCustomConfig = mkCustomConfig Internal.semanticConfigProperties
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LiberalTypeSynonyms #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedLabels #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- This module provides the core functionality of the plugin.
module Ide.Plugin.SemanticTokens.Internal (semanticTokensFull, getSemanticTokensRule, semanticConfigProperties, semanticTokensFullDelta) where
module Ide.Plugin.SemanticTokens.Internal (semanticTokensFull, getSemanticTokensRule, getSyntacticTokensRule, semanticConfigProperties, semanticTokensFullDelta) where

import Control.Concurrent.STM (stateTVar)
import Control.Concurrent.STM.Stats (atomically)
Expand All @@ -20,20 +28,27 @@ import Control.Monad.Except (ExceptT, liftEither,
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Except (runExceptT)
import Control.Monad.Trans.Maybe
import Data.Data (Data (..))
import Data.List
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.Semigroup (First (..))
import Data.Text (Text)
import qualified Data.Text as T
import Development.IDE (Action,
GetDocMap (GetDocMap),
GetHieAst (GetHieAst),
GetParsedModuleWithComments (..),
HieAstResult (HAR, hieAst, hieModule, refMap),
IdeResult, IdeState,
Priority (..),
Recorder, Rules,
WithPriority,
cmapWithPrio, define,
fromNormalizedFilePath,
hieKind)
hieKind,
srcSpanToRange,
useWithStale)
import Development.IDE.Core.PluginUtils (runActionE, useE,
useWithStaleE)
import Development.IDE.Core.Rules (toIdeResult)
Expand All @@ -43,8 +58,9 @@ import Development.IDE.Core.Shake (ShakeExtras (..),
getVirtualFile)
import Development.IDE.GHC.Compat hiding (Warning)
import Development.IDE.GHC.Compat.Util (mkFastString)
import GHC.Parser.Annotation
import Ide.Logger (logWith)
import Ide.Plugin.Error (PluginError (PluginInternalError),
import Ide.Plugin.Error (PluginError (PluginInternalError, PluginRuleFailed),
getNormalizedFilePathE,
handleMaybe,
handleMaybeM)
Expand All @@ -58,10 +74,17 @@ import qualified Language.LSP.Protocol.Lens as L
import Language.LSP.Protocol.Message (MessageResult,
Method (Method_TextDocumentSemanticTokensFull, Method_TextDocumentSemanticTokensFullDelta))
import Language.LSP.Protocol.Types (NormalizedFilePath,
Range,
SemanticTokens,
fromNormalizedFilePath,
type (|?) (InL, InR))
import Prelude hiding (span)
import qualified StmContainers.Map as STM
import Type.Reflection (Typeable, eqTypeRep,
pattern App,
type (:~~:) (HRefl),
typeOf, typeRep,
withTypeable)


$mkSemanticConfigFunctions
Expand All @@ -75,8 +98,17 @@ computeSemanticTokens recorder pid _ nfp = do
config <- lift $ useSemanticConfigAction pid
logWith recorder Debug (LogConfig config)
semanticId <- lift getAndIncreaseSemanticTokensId
(RangeHsSemanticTokenTypes {rangeSemanticList}, mapping) <- useWithStaleE GetSemanticTokens nfp
withExceptT PluginInternalError $ liftEither $ rangeSemanticsSemanticTokens semanticId config mapping rangeSemanticList

(sortOn fst -> tokenList, First mapping) <- do
rangesyntacticTypes <- lift $ useWithStale GetSyntacticTokens nfp
rangesemanticTypes <- lift $ useWithStale GetSemanticTokens nfp
let mk w u (toks, mapping) = (map (fmap w) $ u toks, First mapping)
maybeToExceptT (PluginRuleFailed "no syntactic nor semantic tokens") $ hoistMaybe $
(mk HsSyntacticTokenType rangeSyntacticList <$> rangesyntacticTypes)
<> (mk HsSemanticTokenType rangeSemanticList <$> rangesemanticTypes)

-- NOTE: rangeSemanticsSemanticTokens actually assumes that the tokesn are in order. that means they have to be sorted by position
withExceptT PluginInternalError $ liftEither $ rangeSemanticsSemanticTokens semanticId config mapping tokenList

semanticTokensFull :: Recorder (WithPriority SemanticLog) -> PluginMethodHandler IdeState 'Method_TextDocumentSemanticTokensFull
semanticTokensFull recorder state pid param = runActionE "SemanticTokens.semanticTokensFull" state computeSemanticTokensFull
Expand Down Expand Up @@ -130,6 +162,102 @@ getSemanticTokensRule recorder =
let hsFinder = idSemantic getTyThingMap (hieKindFunMasksKind hieKind) refMap
return $ computeRangeHsSemanticTokenTypeList hsFinder virtualFile ast

getSyntacticTokensRule :: Recorder (WithPriority SemanticLog) -> Rules ()
getSyntacticTokensRule recorder =
define (cmapWithPrio LogShake recorder) $ \GetSyntacticTokens nfp -> handleError recorder $ do
(parsedModule, _) <- withExceptT LogDependencyError $ useWithStaleE GetParsedModuleWithComments nfp
let tokList = computeRangeHsSyntacticTokenTypeList parsedModule
logWith recorder Debug $ LogSyntacticTokens tokList
pure tokList

astTraversalWith :: forall b r. Data b => b -> (forall a. Data a => a -> [r]) -> [r]
astTraversalWith ast f = mconcat $ flip gmapQ ast \y -> f y <> astTraversalWith y f

{-# inline extractTyToTy #-}
extractTyToTy :: forall f a. (Typeable f, Data a) => a -> Maybe (forall r. (forall b. Typeable b => f b -> r) -> r)
extractTyToTy node
| App conRep argRep <- typeOf node
, Just HRefl <- eqTypeRep conRep (typeRep @f)
= Just $ withTypeable argRep \k -> k node
| otherwise = Nothing

{-# inline extractTy #-}
extractTy :: forall b a. (Typeable b, Data a) => a -> Maybe b
extractTy node
| Just HRefl <- eqTypeRep (typeRep @b) (typeOf node)
= Just node
| otherwise = Nothing

computeRangeHsSyntacticTokenTypeList :: ParsedModule -> RangeHsSyntacticTokenTypes
computeRangeHsSyntacticTokenTypeList ParsedModule {pm_parsed_source} =
let toks = astTraversalWith pm_parsed_source \node -> mconcat
[
#if MIN_VERSION_ghc(9,9,0)
maybeToList $ mkFromLocatable TKeyword . (\k -> k \x k' -> k' x) =<< extractTyToTy @EpToken node,
#endif
#if !MIN_VERSION_ghc(9,11,0)
maybeToList $ mkFromLocatable TKeyword . (\x k -> k x) =<< extractTy @AddEpAnn node,
do
EpAnnImportDecl i p s q pkg a <- maybeToList $ extractTy @EpAnnImportDecl node

mapMaybe (mkFromLocatable TKeyword . (\x k -> k x)) $ catMaybes $ [Just i, s, q, pkg, a] <> foldMap (\(l, l') -> [Just l, Just l']) p,
#endif
maybeToList $ mkFromLocatable TComment . (\x k -> k x) =<< extractTy @LEpaComment node,
do
L loc expr <- maybeToList $ extractTy @(LHsExpr GhcPs) node
let fromSimple = maybeToList . flip mkFromLocatable \k -> k loc
case expr of
HsOverLabel {} -> fromSimple TStringLit
HsOverLit _ (OverLit _ lit) -> fromSimple case lit of
HsIntegral {} -> TNumberLit
HsFractional {} -> TNumberLit

HsIsString {} -> TStringLit
HsLit _ lit -> fromSimple case lit of
HsChar {} -> TCharLit
HsCharPrim {} -> TCharLit

HsInt {} -> TNumberLit
HsInteger {} -> TNumberLit
HsIntPrim {} -> TNumberLit
HsWordPrim {} -> TNumberLit
#if MIN_VERSION_ghc(9,9,0)
HsWord8Prim {} -> TNumberLit
HsWord16Prim {} -> TNumberLit
HsWord32Prim {} -> TNumberLit
#endif
HsWord64Prim {} -> TNumberLit
#if MIN_VERSION_ghc(9,9,0)
HsInt8Prim {} -> TNumberLit
HsInt16Prim {} -> TNumberLit
HsInt32Prim {} -> TNumberLit
#endif
HsInt64Prim {} -> TNumberLit
HsFloatPrim {} -> TNumberLit
HsDoublePrim {} -> TNumberLit
HsRat {} -> TNumberLit

HsString {} -> TStringLit
HsStringPrim {} -> TStringLit
#if MIN_VERSION_ghc(9,11,0)
HsMultilineString {} -> TStringLit
#endif
HsGetField _ _ field -> maybeToList $ mkFromLocatable TRecordSelector \k -> k field
#if MIN_VERSION_ghc(9,11,0)
HsProjection _ projs -> foldMap (\dotFieldOcc -> maybeToList $ mkFromLocatable TRecordSelector \k -> k dotFieldOcc.dfoLabel) projs
#else
HsProjection _ projs -> foldMap (\proj -> maybeToList $ mkFromLocatable TRecordSelector \k -> k proj) projs
#endif
_ -> []
]
in RangeHsSyntacticTokenTypes toks

{-# inline mkFromLocatable #-}
mkFromLocatable
:: HsSyntacticTokenType
-> (forall r. (forall a. HasSrcSpan a => a -> r) -> r)
-> Maybe (Range, HsSyntacticTokenType)
mkFromLocatable tt w = w \tok -> let mrange = srcSpanToRange $ getLoc tok in fmap (, tt) mrange

-- taken from /haskell-language-server/plugins/hls-code-range-plugin/src/Ide/Plugin/CodeRange/Rules.hs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
module Ide.Plugin.SemanticTokens.Mappings where

import qualified Data.Array as A
import Data.Function
import Data.List.Extra (chunksOf, (!?))
import qualified Data.Map.Strict as Map
import Data.Maybe (mapMaybe)
Expand All @@ -39,28 +40,34 @@ nameInfixOperator _ = Nothing
-- * 1. Mapping semantic token type to and from the LSP default token type.

-- | map from haskell semantic token type to LSP default token type
toLspTokenType :: SemanticTokensConfig -> HsSemanticTokenType -> SemanticTokenTypes
toLspTokenType conf tk = case tk of
TFunction -> stFunction conf
TVariable -> stVariable conf
TClassMethod -> stClassMethod conf
TTypeVariable -> stTypeVariable conf
TDataConstructor -> stDataConstructor conf
TClass -> stClass conf
TTypeConstructor -> stTypeConstructor conf
TTypeSynonym -> stTypeSynonym conf
TTypeFamily -> stTypeFamily conf
TRecordField -> stRecordField conf
TPatternSynonym -> stPatternSynonym conf
TModule -> stModule conf
TOperator -> stOperator conf
toLspTokenType :: SemanticTokensConfig -> HsTokenType -> SemanticTokenTypes
toLspTokenType conf tk = conf & case tk of
HsSemanticTokenType TFunction -> stFunction
HsSemanticTokenType TVariable -> stVariable
HsSemanticTokenType TClassMethod -> stClassMethod
HsSemanticTokenType TTypeVariable -> stTypeVariable
HsSemanticTokenType TDataConstructor -> stDataConstructor
HsSemanticTokenType TClass -> stClass
HsSemanticTokenType TTypeConstructor -> stTypeConstructor
HsSemanticTokenType TTypeSynonym -> stTypeSynonym
HsSemanticTokenType TTypeFamily -> stTypeFamily
HsSemanticTokenType TRecordField -> stRecordField
HsSemanticTokenType TPatternSynonym -> stPatternSynonym
HsSemanticTokenType TModule -> stModule
HsSemanticTokenType TOperator -> stOperator
HsSyntacticTokenType TKeyword -> stKeyword
HsSyntacticTokenType TComment -> stComment
HsSyntacticTokenType TStringLit -> stStringLit
HsSyntacticTokenType TCharLit -> stCharLit
HsSyntacticTokenType TNumberLit -> stNumberLit
HsSyntacticTokenType TRecordSelector -> stRecordSelector

lspTokenReverseMap :: SemanticTokensConfig -> Map.Map SemanticTokenTypes HsSemanticTokenType
lspTokenReverseMap config
| length xs /= Map.size mr = error "lspTokenReverseMap: token type mapping is not bijection"
| otherwise = mr
where xs = enumFrom minBound
mr = Map.fromList $ map (\x -> (toLspTokenType config x, x)) xs
mr = Map.fromList $ map (\x -> (toLspTokenType config (HsSemanticTokenType x), x)) xs

lspTokenTypeHsTokenType :: SemanticTokensConfig -> SemanticTokenTypes -> Maybe HsSemanticTokenType
lspTokenTypeHsTokenType cf tk = Map.lookup tk (lspTokenReverseMap cf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import Development.IDE.GHC.Compat
import Ide.Plugin.SemanticTokens.Mappings
import Ide.Plugin.SemanticTokens.Types (HieFunMaskKind,
HsSemanticTokenType (TModule),
HsTokenType,
RangeSemanticTokenTypeList,
SemanticTokenId,
SemanticTokensConfig)
Expand Down Expand Up @@ -66,11 +67,11 @@ nameSemanticFromHie hieKind rm n = idSemanticFromRefMap rm (Right n)

-------------------------------------------------

rangeSemanticsSemanticTokens :: SemanticTokenId -> SemanticTokensConfig -> PositionMapping -> RangeSemanticTokenTypeList -> Either Text SemanticTokens
rangeSemanticsSemanticTokens :: SemanticTokenId -> SemanticTokensConfig -> PositionMapping -> [(Range, HsTokenType)] -> Either Text SemanticTokens
rangeSemanticsSemanticTokens sid stc mapping =
makeSemanticTokensWithId (Just sid) . mapMaybe (\(ran, tk) -> toAbsSemanticToken <$> toCurrentRange mapping ran <*> return tk)
where
toAbsSemanticToken :: Range -> HsSemanticTokenType -> SemanticTokenAbsolute
toAbsSemanticToken :: Range -> HsTokenType -> SemanticTokenAbsolute
toAbsSemanticToken (Range (Position startLine startColumn) (Position _endLine endColumn)) tokenType =
let len = endColumn - startColumn
in SemanticTokenAbsolute
Expand Down
Loading
Loading