Skip to content

Commit

Permalink
feat: add quick file selection upon tagging on Chat input (#533) bump:patch
Browse files Browse the repository at this point in the history

* fix: improve inline citation logics without rag

* fix: improve explanation for citation options

* feat: add quick file selection on Chat input
  • Loading branch information
taprosoft authored Nov 28, 2024
1 parent f15abdb commit ab6b3fc
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 51 deletions.
59 changes: 40 additions & 19 deletions libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ def answer_to_citations(self, answer) -> list[InlineEvidence]:
def replace_citation_with_link(self, answer: str):
# Define the regex pattern to match 【number】
pattern = r"【\d+】"

# Regular expression to match merged citations
multi_pattern = r"【([\d,\s]+)】"

# Function to replace merged citations with independent ones
def split_citations(match):
    """Expand a merged citation like 【1, 2】 into independent ones: 【1】【2】.

    Receives the regex match for 【([\\d,\\s]+)】 and returns the
    replacement string.
    """
    numbers = match.group(1).split(",")
    # Skip empty fragments so malformed input like 【1,】 does not
    # produce an empty citation 【】.
    return "".join(f"【{num.strip()}】" for num in numbers if num.strip())

# Replace merged citations in the text
answer = re.sub(multi_pattern, split_citations, answer)

# Find all citations in the answer
matches = re.finditer(pattern, answer)

matched_citations = set()
Expand Down Expand Up @@ -240,25 +254,30 @@ def mindmap_call():
# try streaming first
print("Trying LLM streaming")
for out_msg in self.llm.stream(messages):
if START_ANSWER in output:
if not final_answer:
try:
left_over_answer = output.split(START_ANSWER)[1].lstrip()
except IndexError:
left_over_answer = ""
if left_over_answer:
out_msg.text = left_over_answer + out_msg.text

final_answer += (
out_msg.text.lstrip() if not final_answer else out_msg.text
)
if evidence:
if START_ANSWER in output:
if not final_answer:
try:
left_over_answer = output.split(START_ANSWER)[
1
].lstrip()
except IndexError:
left_over_answer = ""
if left_over_answer:
out_msg.text = left_over_answer + out_msg.text

final_answer += (
out_msg.text.lstrip() if not final_answer else out_msg.text
)
yield Document(channel="chat", content=out_msg.text)

# check for the edge case of citation list is repeated
# with smaller LLMs
if START_CITATION in out_msg.text:
break
else:
yield Document(channel="chat", content=out_msg.text)

# check for the edge case of citation list is repeated
# with smaller LLMs
if START_CITATION in out_msg.text:
break

output += out_msg.text
logprobs += out_msg.logprobs
except NotImplementedError:
Expand Down Expand Up @@ -289,8 +308,10 @@ def mindmap_call():

# yield the final answer
final_answer = self.replace_citation_with_link(final_answer)
yield Document(channel="chat", content=None)
yield Document(channel="chat", content=final_answer)

if final_answer:
yield Document(channel="chat", content=None)
yield Document(channel="chat", content=final_answer)

return answer

Expand Down
3 changes: 3 additions & 0 deletions libs/kotaemon/kotaemon/indices/qa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def find_start_end_phrase(
matches = []
matched_length = 0
for sentence in [start_phrase, end_phrase]:
if sentence is None:
continue

match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()
Expand Down
4 changes: 4 additions & 0 deletions libs/ktem/ktem/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ def make(self):
"<script>"
f"{self._svg_js}"
"</script>"
"<script type='module' "
"src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa
"</script>"
"<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa
)

with gr.Blocks(
Expand Down
17 changes: 17 additions & 0 deletions libs/ktem/ktem/assets/css/main.css
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,20 @@ details.evidence {
color: #10b981;
text-decoration: none;
}

/* pop-up for file tag in chat input*/
.tribute-container ul {
background-color: var(--background-fill-primary) !important;
color: var(--body-text-color) !important;
font-family: var(--font);
font-size: var(--text-md);
}

.tribute-container li.highlight {
background-color: var(--border-color-primary) !important;
}

/* a fix for flickering background in Gradio DataFrame */
tbody:not(.row_odd) {
background: var(--table-even-background-fill);
}
67 changes: 45 additions & 22 deletions libs/ktem/ktem/index/file/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@
}
"""

# JS callback fired when the file-list choices change: builds a Tribute.js
# "@"-mention autocomplete over the chat input textarea. Each choice is a
# [name, id] pair; the inserted text is the quoted file name.
# NOTE(review): a fresh Tribute instance is attached on every change event;
# confirm against Tribute.js docs that repeated attach on the same element
# does not stack duplicate menus.
update_file_list_js = """
function(file_list) {
    var values = [];
    for (var i = 0; i < file_list.length; i++) {
        values.push({
            key: file_list[i][0],
            value: '"' + file_list[i][0] + '"',
        });
    }
    var tribute = new Tribute({
        values: values,
        noMatchTemplate: "",
        allowSpaces: true,
    });
    // Declare with `var` -- the original bare assignment created an
    // implicit global (and would throw in strict mode).
    var input_box = document.querySelector('#chat-input textarea');
    tribute.attach(input_box);
}
"""


class File(gr.File):
"""Subclass from gr.File to maintain the original filename
Expand Down Expand Up @@ -1429,13 +1448,25 @@ def on_building_ui(self):
visible=False,
)
self.selector_user_id = gr.State(value=user_id)
self.selector_choices = gr.JSON(
value=[],
visible=False,
)

def on_register_events(self):
    """Wire up the selector's UI events."""

    def _toggle_selector(mode, user_id):
        # The file selector is only visible in "select" mode.
        return gr.update(visible=mode == "select"), user_id

    self.mode.change(
        fn=_toggle_selector,
        inputs=[self.mode, self._app.user_id],
        outputs=[self.selector, self.selector_user_id],
    )

    # Only the first index drives the "@" quick-file-selection
    # autocomplete in the chat input (pure front-end JS handler).
    if self._index.id == 1:
        self.selector_choices.change(
            fn=None,
            inputs=[self.selector_choices],
            js=update_file_list_js,
            show_progress="hidden",
        )

def as_gradio_component(self):
    """Return the selector's Gradio components consumed by the chat pipeline."""
    components = [self.mode, self.selector, self.selector_user_id]
    return components
Expand Down Expand Up @@ -1468,7 +1499,7 @@ def load_files(self, selected_files, user_id):
available_ids = []
if user_id is None:
# not signed in
return gr.update(value=selected_files, choices=options)
return gr.update(value=selected_files, choices=options), options

with Session(engine) as session:
# get file list from Source table
Expand Down Expand Up @@ -1501,13 +1532,13 @@ def load_files(self, selected_files, user_id):
each for each in selected_files if each in available_ids_set
]

return gr.update(value=selected_files, choices=options)
return gr.update(value=selected_files, choices=options), options

def _on_app_created(self):
    """Populate the selector (and its JSON choices mirror) on app load."""
    app = self._app.app
    app.load(
        fn=self.load_files,
        inputs=[self.selector, self._app.user_id],
        outputs=[self.selector, self.selector_choices],
    )

def on_subscribe_public_events(self):
Expand All @@ -1516,26 +1547,18 @@ def on_subscribe_public_events(self):
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
for event_name in ["onSignIn", "onSignOut"]:
self._app.subscribe_event(
name=event_name,
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)
63 changes: 56 additions & 7 deletions libs/ktem/ktem/pages/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File
from ktem.index.file.ui import File, chat_input_focus_js
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline,
)
Expand All @@ -22,7 +22,7 @@
from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS

from ...utils import SUPPORTED_LANGUAGE_MAP
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex
from .chat_panel import ChatPanel
from .common import STATE
from .control import ConversationControl
Expand Down Expand Up @@ -113,6 +113,7 @@ def on_building_ui(self):
self.state_plot_history = gr.State([])
self.state_plot_panel = gr.State(None)
self.state_follow_up = gr.State(None)
self.first_selector_choices = gr.State(None)

with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app)
Expand All @@ -130,6 +131,11 @@ def on_building_ui(self):
):
index_ui.render()
gr_index = index_ui.as_gradio_component()

# get the file selector choices for the first index
if index_id == 0:
self.first_selector_choices = index_ui.selector_choices

if gr_index:
if isinstance(gr_index, list):
index.selector = tuple(
Expand Down Expand Up @@ -272,6 +278,7 @@ def on_register_events(self):
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self.state_follow_up,
self.first_selector_choices,
],
outputs=[
self.chat_panel.text_input,
Expand All @@ -280,6 +287,9 @@ def on_register_events(self):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.state_follow_up,
# file selector from the first index
self._indices_input[0],
self._indices_input[1],
],
concurrency_limit=20,
show_progress="hidden",
Expand Down Expand Up @@ -426,6 +436,10 @@ def on_register_events(self):
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None,
inputs=None,
js=chat_input_focus_js,
)

self.chat_control.btn_del.click(
Expand Down Expand Up @@ -516,7 +530,12 @@ def on_register_events(self):
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
)

# evidence display on message selection
Expand All @@ -535,7 +554,12 @@ def on_register_events(self):
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
)

self.chat_control.cb_is_public.change(
Expand Down Expand Up @@ -585,14 +609,39 @@ def on_register_events(self):
)

def submit_msg(
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
self,
chat_input,
chat_history,
user_id,
conv_id,
conv_name,
chat_suggest,
first_selector_choices,
):
"""Submit a message to the chatbot"""
if not chat_input:
raise ValueError("Input is empty")

chat_input_text = chat_input.get("text", "")

# get all file names with pattern @"filename" in input_str
file_names, chat_input_text = get_file_names_regex(chat_input_text)
first_selector_choices_map = {
item[0]: item[1] for item in first_selector_choices
}
file_ids = []

if file_names:
for file_name in file_names:
file_id = first_selector_choices_map.get(file_name)
if file_id:
file_ids.append(file_id)

if file_ids:
selector_output = ["select", file_ids]
else:
selector_output = [gr.update(), gr.update()]

# check if regen mode is active
if chat_input_text:
chat_history = chat_history + [(chat_input_text, None)]
Expand Down Expand Up @@ -620,14 +669,14 @@ def submit_msg(
new_conv_name = conv_name
new_chat_suggestion = chat_suggest

return (
return [
{},
chat_history,
new_conv_id,
conv_update,
new_conv_name,
new_chat_suggestion,
)
] + selector_output

def toggle_delete(self, conv_id):
if conv_id:
Expand Down
Loading

0 comments on commit ab6b3fc

Please sign in to comment.