Skip to content

Commit

Permalink
feat: drop source groups in favor of fallback_for (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
Saghen authored Oct 11, 2024
1 parent b330b61 commit 1f0c0f3
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 103 deletions.
75 changes: 39 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,48 +227,51 @@ MiniDeps.add({
-- returns no completion items
-- WARN: This API will have breaking changes during the beta
providers = {
{
{ 'blink.cmp.sources.lsp' },
{ 'blink.cmp.sources.path' },
{ 'blink.cmp.sources.snippets', score_offset = -3 },
},
{ { 'blink.cmp.sources.buffer' } },
{ 'blink.cmp.sources.lsp', name = 'LSP' },
{ 'blink.cmp.sources.path', name = 'Path', score_offset = 3 },
{ 'blink.cmp.sources.snippets', score_offset = -3 },
{ 'blink.cmp.sources.buffer', name = 'Buffer', fallback_for = { 'LSP' } },
},
-- FOR REF: full example
providers = {
-- all of these properties work on every source
{
'blink.cmp.sources.lsp',
name = 'LSP',
keyword_length = 0,
score_offset = 0,
trigger_characters = { 'f', 'o', 'o' },
},
-- the following two sources have additional options
{
-- all of these properties work on every source
{
'blink.cmp.sources.lsp',
keyword_length = 0,
score_offset = 0,
trigger_characters = { 'f', 'o', 'o' },
opts = {},
},
-- the follow two sources have additional options
{
'blink.cmp.sources.path',
opts = {
trailing_slash = false,
label_trailing_slash = true,
get_cwd = function(context) return vim.fn.expand(('#%d:p:h'):format(context.bufnr)) end,
show_hidden_files_by_default = true,
}
},
{
'blink.cmp.sources.snippets',
score_offset = -3,
-- similar to https://github.com/garymjr/nvim-snippets
opts = {
friendly_snippets = true,
search_paths = { vim.fn.stdpath('config') .. '/snippets' },
global_snippets = { 'all' },
extended_filetypes = {},
ignored_filetypes = {},
},
'blink.cmp.sources.path',
name = 'Path',
score_offset = 3,
opts = {
trailing_slash = false,
label_trailing_slash = true,
get_cwd = function(context) return vim.fn.expand(('#%d:p:h'):format(context.bufnr)) end,
show_hidden_files_by_default = true,
}
},
{
'blink.cmp.sources.snippets',
name = 'Snippets',
score_offset = -3,
-- similar to https://github.com/garymjr/nvim-snippets
opts = {
friendly_snippets = true,
search_paths = { vim.fn.stdpath('config') .. '/snippets' },
global_snippets = { 'all' },
extended_filetypes = {},
ignored_filetypes = {},
},
},
{ { 'blink.cmp.sources.buffer' } }
{
'blink.cmp.sources.buffer',
name = 'Buffer',
fallback_for = { 'LSP' },
}
}
},

Expand Down
14 changes: 7 additions & 7 deletions lua/blink/cmp/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
--- @field signature_help? blink.cmp.SignatureHelpTriggerConfig

--- @class blink.cmp.SourceConfig
--- @field providers? blink.cmp.SourceProviderConfig[][]
--- @field providers? blink.cmp.SourceProviderConfig[]
---
--- @class blink.cmp.SourceProviderConfig
--- @field [1]? string
--- @field name string
--- @field fallback_for? string[] | nil
--- @field keyword_length? number | nil
--- @field score_offset? number | nil
--- @field deduplicate? blink.cmp.DeduplicateConfig | nil
Expand Down Expand Up @@ -228,12 +230,10 @@ local config = {
-- returns no completion items
-- WARN: This API will have breaking changes during the beta
providers = {
{
{ 'blink.cmp.sources.lsp' },
{ 'blink.cmp.sources.path' },
{ 'blink.cmp.sources.snippets', score_offset = -2 },
},
{ { 'blink.cmp.sources.buffer' } },
{ 'blink.cmp.sources.lsp', name = 'LSP' },
{ 'blink.cmp.sources.path', name = 'Path', score_offset = 3 },
{ 'blink.cmp.sources.snippets', name = 'Snippets', score_offset = -3 },
{ 'blink.cmp.sources.buffer', name = 'Buffer', fallback_for = { 'LSP' } },
},
},

Expand Down
4 changes: 2 additions & 2 deletions lua/blink/cmp/sources/lib/async.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
--- @field new fun(fn: fun(resolve: fun(result: any), reject: fun(err: any))): blink.cmp.Task
---
--- @field cancel fun(self: blink.cmp.Task)
--- @field map fun(self: blink.cmp.Task, fn: fun(result: any): blink.cmp.Task | any)
--- @field catch fun(self: blink.cmp.Task, fn: fun(err: any): blink.cmp.Task | any)
--- @field map fun(self: blink.cmp.Task, fn: fun(result: any): blink.cmp.Task | any): blink.cmp.Task
--- @field catch fun(self: blink.cmp.Task, fn: fun(err: any): blink.cmp.Task | any): blink.cmp.Task
---
--- @field on_completion fun(self: blink.cmp.Task, cb: fun(result: any))
--- @field on_failure fun(self: blink.cmp.Task, cb: fun(err: any))
Expand Down
94 changes: 58 additions & 36 deletions lua/blink/cmp/sources/lib/context.lua
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
local utils = require('blink.cmp.sources.lib.utils')
local async = require('blink.cmp.sources.lib.async')
local sources_context = {}

--- @param context blink.cmp.Context
--- @param sources_groups blink.cmp.Source[][]
--- @param sources blink.cmp.SourceProvider[]
--- @param on_completions_callback fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[])
function sources_context.new(context, sources_groups, on_completions_callback)
function sources_context.new(context, sources, on_completions_callback)
local self = setmetatable({}, { __index = sources_context })
self.id = context.id
self.sources_groups = sources_groups
self.sources = sources

self.active_request = nil
self.queued_request_context = nil
self.last_sources_group_idx = nil
--- @type fun(context: blink.cmp.Context, items: blink.cmp.CompletionItem[])
self.on_completions_callback = on_completions_callback

Expand All @@ -27,26 +27,12 @@ function sources_context:get_completions(context)
return
end

-- Create a task to get the completions for the first sources group,
-- falling back to the next sources group iteratively if there are no items
local request = self:get_completions_for_group(1, self.sources_groups[1], context)
for idx, sources_group in ipairs(self.sources_groups) do
if idx > 1 then
request = request:map(function(res)
if #res.items > 0 then return res end
return self:get_completions_for_group(idx, sources_group, context)
end)
end
end

-- Send response upstream and run the queued request, if it exists
self.active_request = request:map(function(response)
-- Create a task to get the completions, send responses upstream
-- and run the queued request, if it exists
self.active_request = self:get_completions_for_sources(self.sources, context):map(function(response)
self.active_request = nil
-- only send upstream if the response contains something new
if not response.is_cached or response.sources_group_idx ~= self.last_sources_group_idx then
self.on_completions_callback(context, response.items)
end
self.last_sources_group_idx = response.sources_group_idx
if not response.is_cached then self.on_completions_callback(context, response.items) end

-- run the queued request, if it exists
if self.queued_request_context ~= nil then
Expand All @@ -57,28 +43,26 @@ function sources_context:get_completions(context)
end)
end

--- @param sources_group_idx number
--- @param sources_group blink.cmp.Source[]
--- @param sources blink.cmp.SourceProvider[]
--- @param context blink.cmp.Context
--- @return blink.cmp.Task
function sources_context:get_completions_for_group(sources_group_idx, sources_group, context)
-- get completions for each source in the group
function sources_context:get_completions_for_sources(sources, context)
local non_fallback_sources = vim.tbl_filter(function(source) return source.config.fallback_for == nil end, sources)

-- get completions for each non-fallback source
local tasks = vim.tbl_map(function(source)
-- the source indicates we should refetch when this character is typed
local trigger_character = context.trigger.character
and vim.tbl_contains(source:get_trigger_characters(), context.trigger.character)

-- The TriggerForIncompleteCompletions kind is handled by the source itself
-- The TriggerForIncompleteCompletions kind is handled by the source provider itself
local source_context = require('blink.cmp.utils').shallow_copy(context)
source_context.trigger = trigger_character
and { kind = vim.lsp.protocol.CompletionTriggerKind.TriggerCharacter, character = context.trigger.character }
or { kind = vim.lsp.protocol.CompletionTriggerKind.Invoked }

return source:get_completions(source_context):catch(function(err)
vim.print(source.name .. ': failed to get completions with error: ' .. err)
return { is_incomplete_forward = false, is_incomplete_backward = false, items = {} }
end)
end, sources_group)
return self:get_completions_with_fallbacks(source_context, source, sources)
end, non_fallback_sources)

-- wait for all the tasks to complete
return async.task
Expand All @@ -91,21 +75,59 @@ function sources_context:get_completions_for_group(sources_group_idx, sources_gr
for idx, task_result in ipairs(tasks_results) do
if task_result.status == async.STATUS.COMPLETED then
is_cached = is_cached and (task_result.result.is_cached or false)
local source = sources_group[idx]
local source = sources[idx]
--- @type blink.cmp.CompletionResponse
local response = task_result.result
response.items = source:filter_completions(response)
if source:should_show_completions(context, response) then vim.list_extend(items, response.items) end
end
end
return { sources_group_idx = sources_group_idx, is_cached = is_cached, items = items }
return { is_cached = is_cached, items = items }
end)
:catch(function(err)
vim.print('failed to get completions for group with error: ' .. err)
return { sources_group_idx = sources_group_idx, is_cached = false, items = {} }
vim.print('failed to get completions for sources with error: ' .. err)
return { is_cached = false, items = {} }
end)
end

--- Runs the source's get_completions function, falling back to other sources
--- with fallback_for = { source.name } if the source returns no completion items
--- @param context blink.cmp.Context
--- @param source blink.cmp.SourceProvider
--- @param sources blink.cmp.SourceProvider[]
--- @return blink.cmp.Task
--- TODO: When a source has multiple fallbacks, we may end up with duplicate completion items
function sources_context:get_completions_with_fallbacks(context, source, sources)
local fallback_sources = vim.tbl_filter(
function(fallback_source)
return fallback_source.name ~= source.name
and fallback_source.config.fallback_for ~= nil
and vim.tbl_contains(fallback_source.config.fallback_for, source.name)
end,
sources
)

return source:get_completions(context):map(function(response)
-- source returned completions, no need to fallback
if #response.items > 0 or #fallback_sources == 0 then return response end

-- run fallbacks
return async.task
.await_all(vim.tbl_map(function(fallback) return fallback:get_completions(context) end, fallback_sources))
:map(function(task_results)
local successful_task_results = vim.tbl_filter(
function(task_result) return task_result.status == async.STATUS.COMPLETED end,
task_results
)
local fallback_responses = vim.tbl_map(
function(task_result) return task_result.result end,
successful_task_results
)
return utils.concat_responses(fallback_responses)
end)
end)
end

function sources_context:destroy()
self.on_completions_callback = function() end
if self.active_request ~= nil then self.active_request:cancel() end
Expand Down
32 changes: 12 additions & 20 deletions lua/blink/cmp/sources/lib/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ local config = require('blink.cmp.config')
local sources = {
current_context = nil,
sources_registered = false,
sources_groups = {},
providers = {},
on_completions_callback = function(_, _) end,
}

function sources.register()
assert(#sources.sources_groups == 0, 'Sources have already been registered')
assert(#sources.providers == 0, 'Sources have already been registered')

for _, sources_group in ipairs(config.sources.providers) do
local group = {}
for _, source_config in ipairs(sources_group) do
table.insert(group, require('blink.cmp.sources.lib.source').new(source_config))
end
table.insert(sources.sources_groups, group)
for _, source_config in ipairs(config.sources.providers) do
table.insert(sources.providers, require('blink.cmp.sources.lib.provider').new(source_config))
end
end

Expand All @@ -29,8 +25,7 @@ function sources.get_trigger_characters()
end

local trigger_characters = {}
-- todo: should this be all source groups?
for _, source in pairs(sources.sources_groups[1]) do
for _, source in pairs(sources.providers) do
local source_trigger_characters = source:get_trigger_characters()
for _, char in ipairs(source_trigger_characters) do
if not blocked_trigger_characters[char] then table.insert(trigger_characters, char) end
Expand All @@ -48,7 +43,7 @@ function sources.request_completions(context)
if is_new_context then
if sources.current_context ~= nil then sources.current_context:destroy() end
sources.current_context =
require('blink.cmp.sources.lib.context').new(context, sources.sources_groups, sources.on_completions_callback)
require('blink.cmp.sources.lib.context').new(context, sources.providers, sources.on_completions_callback)
end

sources.current_context:get_completions(context)
Expand All @@ -68,14 +63,11 @@ end
--- @return fun(): nil Cancelation function
function sources.resolve(item, callback)
local item_source = nil
for _, group in ipairs(sources.sources_groups) do
for _, source in ipairs(group) do
if source.name == item.source then
item_source = source
break
end
for _, source in ipairs(sources.providers) do
if source.name == item.source then
item_source = source
break
end
if item_source ~= nil then break end
end

if item_source == nil then
Expand Down Expand Up @@ -105,7 +97,7 @@ function sources.get_signature_help_trigger_characters()
local retrigger_characters = {}

-- todo: should this be all source groups?
for _, source in ipairs(sources.sources_groups[1]) do
for _, source in ipairs(sources.providers) do
local res = source:get_signature_help_trigger_characters()
for _, char in ipairs(res.trigger_characters) do
if not blocked_trigger_characters[char] then table.insert(trigger_characters, char) end
Expand All @@ -121,7 +113,7 @@ end
--- @param callback fun(signature_helps: lsp.SignatureHelp)
function sources.get_signature_help(context, callback)
local tasks = {}
for _, source in ipairs(sources.sources_groups[1]) do
for _, source in ipairs(sources.providers) do
table.insert(tasks, source:get_signature_help(context))
end
sources.current_signature_help = async.task.await_all(tasks):map(function(tasks_results)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ local source = {}

--- @param config blink.cmp.SourceProviderConfig
function source.new(config)
assert(type(config.name) == 'string', 'Each source in config.sources.providers must have a "name" of type string')

local self = setmetatable({}, { __index = source })
self.name = config[1]
self.name = config.name
--- @type blink.cmp.Source
self.module = require(config[1]).new(config.opts or {})
self.config = config
Expand Down Expand Up @@ -45,13 +47,17 @@ function source:get_completions(context)
for _, item in ipairs(response.items) do
item.score_offset = (item.score_offset or 0) + (self.config.score_offset or 0)
item.cursor_column = context.cursor[2]
item.source = self.config[1]
item.source = self.name
end

self.last_response = require('blink.cmp.utils').shallow_copy(response)
self.last_response.is_cached = true
return response
end)
:catch(function(err)
vim.print('failed to get completions with error: ' .. err)
return { is_incomplete_forward = false, is_incomplete_backward = false, items = {} }
end)
end

--- @param response blink.cmp.CompletionResponse
Expand Down
Loading

0 comments on commit 1f0c0f3

Please sign in to comment.