Duplicates workflow to separate POC from Prod and simplify retrieval
Browse files- app.py +47 -18
- climateqa/chat.py +4 -3
- climateqa/engine/chains/answer_rag.py +1 -1
- climateqa/engine/chains/query_transformation.py +1 -1
- climateqa/engine/chains/retrieve_documents.py +85 -58
- climateqa/engine/graph.py +99 -29
- climateqa/handle_stream_events.py +3 -3
- front/tabs/chat_interface.py +1 -1
- front/tabs/tab_examples.py +1 -1
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -9,7 +9,7 @@ from climateqa.engine.embeddings import get_embeddings_function
|
|
| 9 |
from climateqa.engine.llm import get_llm
|
| 10 |
from climateqa.engine.vectorstore import get_pinecone_vectorstore
|
| 11 |
from climateqa.engine.reranker import get_reranker
|
| 12 |
-
from climateqa.engine.graph import make_graph_agent
|
| 13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
| 14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
| 15 |
|
|
@@ -69,12 +69,19 @@ vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os
|
|
| 69 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
| 70 |
reranker = get_reranker("nano")
|
| 71 |
|
| 72 |
-
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
|
|
|
| 76 |
async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
|
| 77 |
yield event
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
# --------------------------------------------------------------------
|
|
@@ -205,7 +212,7 @@ def event_handling(
|
|
| 205 |
|
| 206 |
new_sources_hmtl = gr.State([])
|
| 207 |
|
| 208 |
-
|
| 209 |
|
| 210 |
for button in [config_button, close_config_modal]:
|
| 211 |
button.click(
|
|
@@ -213,18 +220,38 @@ def event_handling(
|
|
| 213 |
inputs=[config_open],
|
| 214 |
outputs=[config_modal, config_open]
|
| 215 |
)
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
|
| 230 |
current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
|
|
@@ -234,10 +261,12 @@ def event_handling(
|
|
| 234 |
for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
|
| 235 |
component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
| 236 |
|
| 237 |
-
|
| 238 |
# Search for papers
|
| 239 |
for component in [textbox, examples_hidden]:
|
| 240 |
component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
def main_ui():
|
| 243 |
# config_open = gr.State(True)
|
|
@@ -246,12 +275,12 @@ def main_ui():
|
|
| 246 |
|
| 247 |
with gr.Tabs():
|
| 248 |
cqa_components = cqa_tab(tab_name = "ClimateQ&A")
|
| 249 |
-
|
| 250 |
|
| 251 |
create_about_tab()
|
| 252 |
|
| 253 |
event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
|
| 254 |
-
|
| 255 |
|
| 256 |
demo.queue()
|
| 257 |
|
|
|
|
| 9 |
from climateqa.engine.llm import get_llm
|
| 10 |
from climateqa.engine.vectorstore import get_pinecone_vectorstore
|
| 11 |
from climateqa.engine.reranker import get_reranker
|
| 12 |
+
from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
|
| 13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
| 14 |
from climateqa.chat import start_chat, chat_stream, finish_chat
|
| 15 |
|
|
|
|
| 69 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
| 70 |
reranker = get_reranker("nano")
|
| 71 |
|
| 72 |
+
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
| 73 |
+
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2
|
| 74 |
|
| 75 |
|
| 76 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
| 77 |
+
print("chat cqa - message received")
|
| 78 |
async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
|
| 79 |
yield event
|
| 80 |
+
|
| 81 |
+
async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
| 82 |
+
print("chat poc - message received")
|
| 83 |
+
async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
|
| 84 |
+
yield event
|
| 85 |
|
| 86 |
|
| 87 |
# --------------------------------------------------------------------
|
|
|
|
| 212 |
|
| 213 |
new_sources_hmtl = gr.State([])
|
| 214 |
|
| 215 |
+
print("textbox id : ", textbox.elem_id)
|
| 216 |
|
| 217 |
for button in [config_button, close_config_modal]:
|
| 218 |
button.click(
|
|
|
|
| 220 |
inputs=[config_open],
|
| 221 |
outputs=[config_modal, config_open]
|
| 222 |
)
|
| 223 |
+
|
| 224 |
+
if tab_name == "ClimateQ&A":
|
| 225 |
+
print("chat cqa - message sent")
|
| 226 |
+
|
| 227 |
+
# Event for textbox
|
| 228 |
+
(textbox
|
| 229 |
+
.submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
|
| 230 |
+
.then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
|
| 231 |
+
.then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
|
| 232 |
+
)
|
| 233 |
+
# Event for examples_hidden
|
| 234 |
+
(examples_hidden
|
| 235 |
+
.change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
|
| 236 |
+
.then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
|
| 237 |
+
.then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
elif tab_name == "Beta - POC Adapt'Action":
|
| 241 |
+
print("chat poc - message sent")
|
| 242 |
+
# Event for textbox
|
| 243 |
+
(textbox
|
| 244 |
+
.submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{textbox.elem_id}")
|
| 245 |
+
.then(chat_poc, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{textbox.elem_id}")
|
| 246 |
+
.then(finish_chat, None, [textbox], api_name=f"finish_chat_{textbox.elem_id}")
|
| 247 |
+
)
|
| 248 |
+
# Event for examples_hidden
|
| 249 |
+
(examples_hidden
|
| 250 |
+
.change(start_chat, [examples_hidden, chatbot, search_only], [examples_hidden, tabs, chatbot, sources_raw], queue=False, api_name=f"start_chat_{examples_hidden.elem_id}")
|
| 251 |
+
.then(chat_poc, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, new_sources_hmtl, output_query, output_language, new_figures, current_graphs], concurrency_limit=8, api_name=f"chat_{examples_hidden.elem_id}")
|
| 252 |
+
.then(finish_chat, None, [textbox], api_name=f"finish_chat_{examples_hidden.elem_id}")
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
|
| 256 |
new_sources_hmtl.change(lambda x : x, inputs = [new_sources_hmtl], outputs = [sources_textbox])
|
| 257 |
current_graphs.change(lambda x: x, inputs=[current_graphs], outputs=[graphs_container])
|
|
|
|
| 261 |
for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
|
| 262 |
component.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
| 263 |
|
|
|
|
| 264 |
# Search for papers
|
| 265 |
for component in [textbox, examples_hidden]:
|
| 266 |
component.submit(find_papers, [component, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
|
| 271 |
def main_ui():
|
| 272 |
# config_open = gr.State(True)
|
|
|
|
| 275 |
|
| 276 |
with gr.Tabs():
|
| 277 |
cqa_components = cqa_tab(tab_name = "ClimateQ&A")
|
| 278 |
+
local_cqa_components = cqa_tab(tab_name = "Beta - POC Adapt'Action")
|
| 279 |
|
| 280 |
create_about_tab()
|
| 281 |
|
| 282 |
event_handling(cqa_components, config_components, tab_name = 'ClimateQ&A')
|
| 283 |
+
event_handling(local_cqa_components, config_components, tab_name = 'Beta - POC Adapt\'Action')
|
| 284 |
|
| 285 |
demo.queue()
|
| 286 |
|
climateqa/chat.py
CHANGED
|
@@ -119,6 +119,7 @@ async def chat_stream(
|
|
| 119 |
start_streaming = False
|
| 120 |
graphs_html = ""
|
| 121 |
used_documents = []
|
|
|
|
| 122 |
answer_message_content = ""
|
| 123 |
|
| 124 |
# Define processing steps
|
|
@@ -138,8 +139,8 @@ async def chat_stream(
|
|
| 138 |
|
| 139 |
# Handle document retrieval
|
| 140 |
if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
|
| 141 |
-
history, used_documents = handle_retrieved_documents(
|
| 142 |
-
event, history, used_documents
|
| 143 |
)
|
| 144 |
if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
|
| 145 |
docs = event["data"]["input"]["documents"]
|
|
@@ -180,7 +181,7 @@ async def chat_stream(
|
|
| 180 |
# Handle query transformation
|
| 181 |
if event["name"] == "transform_query" and event["event"] == "on_chain_end":
|
| 182 |
if hasattr(history[-1], "content"):
|
| 183 |
-
sub_questions = [q["question"] for q in event["data"]["output"]["questions_list"]]
|
| 184 |
history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
|
| 185 |
|
| 186 |
yield history, docs_html, output_query, output_language, related_contents, graphs_html
|
|
|
|
| 119 |
start_streaming = False
|
| 120 |
graphs_html = ""
|
| 121 |
used_documents = []
|
| 122 |
+
retrieved_contents = []
|
| 123 |
answer_message_content = ""
|
| 124 |
|
| 125 |
# Define processing steps
|
|
|
|
| 139 |
|
| 140 |
# Handle document retrieval
|
| 141 |
if event["event"] == "on_chain_end" and event["name"] in ["retrieve_documents","retrieve_local_data"] and event["data"]["output"] != None:
|
| 142 |
+
history, used_documents, retrieved_contents = handle_retrieved_documents(
|
| 143 |
+
event, history, used_documents, retrieved_contents
|
| 144 |
)
|
| 145 |
if event["event"] == "on_chain_end" and event["name"] == "answer_search" :
|
| 146 |
docs = event["data"]["input"]["documents"]
|
|
|
|
| 181 |
# Handle query transformation
|
| 182 |
if event["name"] == "transform_query" and event["event"] == "on_chain_end":
|
| 183 |
if hasattr(history[-1], "content"):
|
| 184 |
+
sub_questions = [q["question"] + "-> relevant sources : " + str(q["sources"]) for q in event["data"]["output"]["questions_list"]]
|
| 185 |
history[-1].content += "Decompose question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
|
| 186 |
|
| 187 |
yield history, docs_html, output_query, output_language, related_contents, graphs_html
|
climateqa/engine/chains/answer_rag.py
CHANGED
|
@@ -61,7 +61,7 @@ def make_rag_node(llm,with_docs = True):
|
|
| 61 |
rag_chain = make_rag_chain(llm)
|
| 62 |
else:
|
| 63 |
rag_chain = make_rag_chain_without_docs(llm)
|
| 64 |
-
|
| 65 |
async def answer_rag(state,config):
|
| 66 |
print("---- Answer RAG ----")
|
| 67 |
start_time = time.time()
|
|
|
|
| 61 |
rag_chain = make_rag_chain(llm)
|
| 62 |
else:
|
| 63 |
rag_chain = make_rag_chain_without_docs(llm)
|
| 64 |
+
|
| 65 |
async def answer_rag(state,config):
|
| 66 |
print("---- Answer RAG ----")
|
| 67 |
start_time = time.time()
|
climateqa/engine/chains/query_transformation.py
CHANGED
|
@@ -60,7 +60,7 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
|
| 60 |
|
| 61 |
|
| 62 |
ROUTING_INDEX = {
|
| 63 |
-
"IPx":["IPCC", "
|
| 64 |
"POC": ["AcclimaTerra", "PCAET","Biodiv"],
|
| 65 |
"OpenAlex":["OpenAlex"],
|
| 66 |
}
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
ROUTING_INDEX = {
|
| 63 |
+
"IPx":["IPCC", "IPBES", "IPOS"],
|
| 64 |
"POC": ["AcclimaTerra", "PCAET","Biodiv"],
|
| 65 |
"OpenAlex":["OpenAlex"],
|
| 66 |
}
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
|
@@ -15,7 +15,9 @@ from ..utils import log_event
|
|
| 15 |
from langchain_core.vectorstores import VectorStore
|
| 16 |
from typing import List
|
| 17 |
from langchain_core.documents.base import Document
|
|
|
|
| 18 |
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
def divide_into_parts(target, parts):
|
|
@@ -272,12 +274,27 @@ def concatenate_documents(index, source_type, docs_question_dict, k_by_question,
|
|
| 272 |
|
| 273 |
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
| 274 |
# @chain
|
| 275 |
-
async def retrieve_documents(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
"""
|
| 277 |
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
|
| 278 |
|
| 279 |
Args:
|
| 280 |
state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
|
|
|
|
| 281 |
config (dict): Configuration settings for logging and other purposes.
|
| 282 |
vectorstore (object): The vector store used to retrieve relevant documents.
|
| 283 |
reranker (object): The reranker used to rerank the retrieved documents.
|
|
@@ -290,35 +307,6 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
|
|
| 290 |
Returns:
|
| 291 |
dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
|
| 292 |
"""
|
| 293 |
-
# TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
|
| 294 |
-
docs = state.get("documents", [])
|
| 295 |
-
related_content = state.get("related_content", [])
|
| 296 |
-
|
| 297 |
-
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
| 298 |
-
search_only = state["search_only"]
|
| 299 |
-
|
| 300 |
-
reports = state["reports"]
|
| 301 |
-
|
| 302 |
-
# Get the current question
|
| 303 |
-
# current_question = state["questions_list"][0]
|
| 304 |
-
# remaining_questions = state["remaining_questions"][1:]
|
| 305 |
-
|
| 306 |
-
current_question_id = None
|
| 307 |
-
print("Questions Indexs", list(range(len(state["questions_list"]))), "- Handled questions : " ,state["handled_questions_index"])
|
| 308 |
-
|
| 309 |
-
for i in range(len(state["questions_list"])):
|
| 310 |
-
current_question = state["questions_list"][i]
|
| 311 |
-
|
| 312 |
-
if i not in state["handled_questions_index"] and current_question["source_type"] == source_type:
|
| 313 |
-
current_question_id = i
|
| 314 |
-
break
|
| 315 |
-
|
| 316 |
-
# TODO filter on source_type
|
| 317 |
-
|
| 318 |
-
k_by_question = k_final // state["n_questions"]["total"]
|
| 319 |
-
k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
|
| 320 |
-
k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
|
| 321 |
-
|
| 322 |
sources = current_question["sources"]
|
| 323 |
question = current_question["question"]
|
| 324 |
index = current_question["index"]
|
|
@@ -329,8 +317,7 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
|
|
| 329 |
|
| 330 |
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
|
| 331 |
|
| 332 |
-
|
| 333 |
-
if source_type == "IPx": # always true for now #TODO rename to IPx
|
| 334 |
docs_question_dict = await get_IPCC_relevant_documents(
|
| 335 |
query = question,
|
| 336 |
vectorstore=vectorstore,
|
|
@@ -359,7 +346,6 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
|
|
| 359 |
k_images= k_by_question
|
| 360 |
)
|
| 361 |
|
| 362 |
-
|
| 363 |
# Rerank
|
| 364 |
if reranker is not None and rerank_by_question:
|
| 365 |
with suppress_output():
|
|
@@ -381,35 +367,72 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
|
|
| 381 |
docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
|
| 382 |
images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
|
| 383 |
|
| 384 |
-
|
| 385 |
-
# docs.extend(docs_question)
|
| 386 |
-
# related_content.extend(images_question)
|
| 387 |
-
docs = docs_question
|
| 388 |
-
related_content = images_question
|
| 389 |
-
new_state = {"documents":docs, "related_contents": related_content, "handled_questions_index": [current_question_id]}
|
| 390 |
-
print("Updated state with question ", current_question_id, " added ", len(docs), " documents")
|
| 391 |
-
return new_state
|
| 392 |
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
| 396 |
|
| 397 |
@chain
|
| 398 |
async def retrieve_IPx_docs(state, config):
|
| 399 |
source_type = "IPx"
|
|
|
|
|
|
|
| 400 |
# return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
|
| 401 |
|
| 402 |
-
state =
|
| 403 |
-
state
|
| 404 |
-
config=
|
| 405 |
source_type=source_type,
|
|
|
|
| 406 |
vectorstore=vectorstore,
|
| 407 |
-
reranker=
|
| 408 |
-
llm=llm,
|
| 409 |
rerank_by_question=rerank_by_question,
|
| 410 |
-
k_final=k_final,
|
| 411 |
-
k_before_reranking=k_before_reranking,
|
| 412 |
-
k_summary=k_summary
|
| 413 |
)
|
| 414 |
return state
|
| 415 |
|
|
@@ -420,19 +443,23 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
|
|
| 420 |
|
| 421 |
@chain
|
| 422 |
async def retrieve_POC_docs_node(state, config):
|
|
|
|
|
|
|
|
|
|
| 423 |
source_type = "POC"
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
|
|
|
|
|
|
| 427 |
source_type=source_type,
|
|
|
|
| 428 |
vectorstore=vectorstore,
|
| 429 |
-
reranker=
|
| 430 |
-
llm=llm,
|
| 431 |
rerank_by_question=rerank_by_question,
|
| 432 |
-
k_final=k_final,
|
| 433 |
-
k_before_reranking=k_before_reranking,
|
| 434 |
-
|
| 435 |
-
)
|
| 436 |
return state
|
| 437 |
|
| 438 |
return retrieve_POC_docs_node
|
|
|
|
| 15 |
from langchain_core.vectorstores import VectorStore
|
| 16 |
from typing import List
|
| 17 |
from langchain_core.documents.base import Document
|
| 18 |
+
import asyncio
|
| 19 |
|
| 20 |
+
from typing import Any, Dict, List, Tuple
|
| 21 |
|
| 22 |
|
| 23 |
def divide_into_parts(target, parts):
|
|
|
|
| 274 |
|
| 275 |
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
| 276 |
# @chain
|
| 277 |
+
async def retrieve_documents(
|
| 278 |
+
current_question: Dict[str, Any],
|
| 279 |
+
config: Dict[str, Any],
|
| 280 |
+
source_type: str,
|
| 281 |
+
vectorstore: VectorStore,
|
| 282 |
+
reranker: Any,
|
| 283 |
+
search_figures: bool = False,
|
| 284 |
+
search_only: bool = False,
|
| 285 |
+
reports: list = [],
|
| 286 |
+
rerank_by_question: bool = True,
|
| 287 |
+
k_images_by_question: int = 5,
|
| 288 |
+
k_before_reranking: int = 100,
|
| 289 |
+
k_by_question: int = 5,
|
| 290 |
+
k_summary_by_question: int = 3
|
| 291 |
+
) -> Tuple[List[Document], List[Document]]:
|
| 292 |
"""
|
| 293 |
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
|
| 294 |
|
| 295 |
Args:
|
| 296 |
state (dict): The current state containing documents, related content, relevant content sources, remaining questions and n_questions.
|
| 297 |
+
current_question (dict): The current question being processed.
|
| 298 |
config (dict): Configuration settings for logging and other purposes.
|
| 299 |
vectorstore (object): The vector store used to retrieve relevant documents.
|
| 300 |
reranker (object): The reranker used to rerank the retrieved documents.
|
|
|
|
| 307 |
Returns:
|
| 308 |
dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
|
| 309 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
sources = current_question["sources"]
|
| 311 |
question = current_question["question"]
|
| 312 |
index = current_question["index"]
|
|
|
|
| 317 |
|
| 318 |
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
|
| 319 |
|
| 320 |
+
if source_type == "IPx":
|
|
|
|
| 321 |
docs_question_dict = await get_IPCC_relevant_documents(
|
| 322 |
query = question,
|
| 323 |
vectorstore=vectorstore,
|
|
|
|
| 346 |
k_images= k_by_question
|
| 347 |
)
|
| 348 |
|
|
|
|
| 349 |
# Rerank
|
| 350 |
if reranker is not None and rerank_by_question:
|
| 351 |
with suppress_output():
|
|
|
|
| 367 |
docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
|
| 368 |
images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
|
| 369 |
|
| 370 |
+
return docs_question, images_question
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
|
| 373 |
+
async def retrieve_documents_for_all_questions(state, config, source_type, to_handle_questions_index, vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
|
| 374 |
+
"""
|
| 375 |
+
Retrieve documents in parallel for all questions.
|
| 376 |
+
"""
|
| 377 |
+
# to_handle_questions_index = [x for x in state["questions_list"] if x["source_type"] == "IPx"]
|
| 378 |
+
|
| 379 |
+
# TODO split les questions selon le type de sources dans le state question + conditions sur le nombre de questions traités par type de source
|
| 380 |
+
docs = state.get("documents", [])
|
| 381 |
+
related_content = state.get("related_content", [])
|
| 382 |
+
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
|
| 383 |
+
search_only = state["search_only"]
|
| 384 |
+
reports = state["reports"]
|
| 385 |
+
|
| 386 |
+
k_by_question = k_final // state["n_questions"]["total"]
|
| 387 |
+
k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
|
| 388 |
+
k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
|
| 389 |
+
k_before_reranking=100
|
| 390 |
+
|
| 391 |
+
tasks = [
|
| 392 |
+
retrieve_documents(
|
| 393 |
+
current_question=question,
|
| 394 |
+
config=config,
|
| 395 |
+
source_type=source_type,
|
| 396 |
+
vectorstore=vectorstore,
|
| 397 |
+
reranker=reranker,
|
| 398 |
+
search_figures=search_figures,
|
| 399 |
+
search_only=search_only,
|
| 400 |
+
reports=reports,
|
| 401 |
+
rerank_by_question=rerank_by_question,
|
| 402 |
+
k_images_by_question=k_images_by_question,
|
| 403 |
+
k_before_reranking=k_before_reranking,
|
| 404 |
+
k_by_question=k_by_question,
|
| 405 |
+
k_summary_by_question=k_summary_by_question
|
| 406 |
+
)
|
| 407 |
+
for i, question in enumerate(state["questions_list"]) if i in to_handle_questions_index
|
| 408 |
+
]
|
| 409 |
+
results = await asyncio.gather(*tasks)
|
| 410 |
+
# Combine results
|
| 411 |
+
new_state = {"documents": [], "related_contents": [], "handled_questions_index": to_handle_questions_index}
|
| 412 |
+
for docs_question, images_question in results:
|
| 413 |
+
new_state["documents"].extend(docs_question)
|
| 414 |
+
new_state["related_contents"].extend(images_question)
|
| 415 |
+
return new_state
|
| 416 |
|
| 417 |
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
| 418 |
|
| 419 |
@chain
|
| 420 |
async def retrieve_IPx_docs(state, config):
|
| 421 |
source_type = "IPx"
|
| 422 |
+
IPx_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
| 423 |
+
|
| 424 |
# return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
|
| 425 |
|
| 426 |
+
state = await retrieve_documents_for_all_questions(
|
| 427 |
+
state=state,
|
| 428 |
+
config=config,
|
| 429 |
source_type=source_type,
|
| 430 |
+
to_handle_questions_index=IPx_questions_index,
|
| 431 |
vectorstore=vectorstore,
|
| 432 |
+
reranker=reranker,
|
|
|
|
| 433 |
rerank_by_question=rerank_by_question,
|
| 434 |
+
k_final=k_final,
|
| 435 |
+
k_before_reranking=k_before_reranking,
|
|
|
|
| 436 |
)
|
| 437 |
return state
|
| 438 |
|
|
|
|
| 443 |
|
| 444 |
@chain
|
| 445 |
async def retrieve_POC_docs_node(state, config):
|
| 446 |
+
if "POC region" not in state["relevant_content_sources_selection"] :
|
| 447 |
+
return {}
|
| 448 |
+
|
| 449 |
source_type = "POC"
|
| 450 |
+
POC_questions_index = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
| 451 |
+
|
| 452 |
+
state = await retrieve_documents_for_all_questions(
|
| 453 |
+
state=state,
|
| 454 |
+
config=config,
|
| 455 |
source_type=source_type,
|
| 456 |
+
to_handle_questions_index=POC_questions_index,
|
| 457 |
vectorstore=vectorstore,
|
| 458 |
+
reranker=reranker,
|
|
|
|
| 459 |
rerank_by_question=rerank_by_question,
|
| 460 |
+
k_final=k_final,
|
| 461 |
+
k_before_reranking=k_before_reranking,
|
| 462 |
+
)
|
|
|
|
| 463 |
return state
|
| 464 |
|
| 465 |
return retrieve_POC_docs_node
|
climateqa/engine/graph.py
CHANGED
|
@@ -95,10 +95,10 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
|
|
| 95 |
def route_continue_retrieve_documents(state):
|
| 96 |
index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
| 97 |
questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
|
| 98 |
-
if questions_ipx_finished and state["search_only"]:
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
return "
|
| 102 |
else:
|
| 103 |
return "retrieve_documents"
|
| 104 |
|
|
@@ -113,10 +113,10 @@ def route_continue_retrieve_documents(state):
|
|
| 113 |
def route_continue_retrieve_local_documents(state):
|
| 114 |
index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
| 115 |
questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
|
| 116 |
-
if questions_poc_finished and state["search_only"]:
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
return "
|
| 120 |
else:
|
| 121 |
return "retrieve_local_data"
|
| 122 |
|
|
@@ -139,8 +139,7 @@ def route_retrieve_documents(state):
|
|
| 139 |
|
| 140 |
if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] :
|
| 141 |
sources_to_retrieve.append("retrieve_graphs")
|
| 142 |
-
|
| 143 |
-
sources_to_retrieve.append("retrieve_local_data")
|
| 144 |
if sources_to_retrieve == []:
|
| 145 |
return END
|
| 146 |
return sources_to_retrieve
|
|
@@ -160,7 +159,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
| 160 |
answer_ai_impact = make_ai_impact_node(llm)
|
| 161 |
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
| 162 |
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
| 163 |
-
retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
| 164 |
answer_rag = make_rag_node(llm, with_docs=True)
|
| 165 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
| 166 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
|
@@ -175,7 +174,7 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
| 175 |
workflow.add_node("answer_chitchat", answer_chitchat)
|
| 176 |
workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
|
| 177 |
workflow.add_node("retrieve_graphs", retrieve_graphs)
|
| 178 |
-
workflow.add_node("retrieve_local_data", retrieve_local_data)
|
| 179 |
workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
|
| 180 |
workflow.add_node("retrieve_documents", retrieve_documents)
|
| 181 |
workflow.add_node("answer_rag", answer_rag)
|
|
@@ -202,17 +201,92 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
| 202 |
route_translation,
|
| 203 |
make_id_dict(["translate_query","transform_query"])
|
| 204 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
workflow.add_conditional_edges(
|
| 206 |
-
"
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
make_id_dict([END,"retrieve_documents","answer_search"])
|
| 210 |
)
|
|
|
|
| 211 |
workflow.add_conditional_edges(
|
| 212 |
-
"
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
make_id_dict([END,"retrieve_local_data","answer_search"])
|
| 216 |
)
|
| 217 |
|
| 218 |
workflow.add_conditional_edges(
|
|
@@ -223,19 +297,13 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
| 223 |
workflow.add_conditional_edges(
|
| 224 |
"transform_query",
|
| 225 |
route_retrieve_documents,
|
| 226 |
-
make_id_dict(["retrieve_graphs",
|
| 227 |
)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
# workflow.add_conditional_edges(
|
| 231 |
-
# "transform_query",
|
| 232 |
-
# lambda state : "retrieve_graphs" if "POC region" in state["relevant_content_sources_selection"] else END,
|
| 233 |
-
# make_id_dict(["retrieve_local_data", END])
|
| 234 |
-
# )
|
| 235 |
|
| 236 |
# Define the edges
|
| 237 |
workflow.add_edge("translate_query", "transform_query")
|
| 238 |
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
|
|
|
| 239 |
# workflow.add_edge("transform_query", END) # TODO remove
|
| 240 |
|
| 241 |
workflow.add_edge("retrieve_graphs", END)
|
|
@@ -243,7 +311,9 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
| 243 |
workflow.add_edge("answer_rag_no_docs", END)
|
| 244 |
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
| 245 |
workflow.add_edge("retrieve_graphs_chitchat", END)
|
| 246 |
-
|
|
|
|
|
|
|
| 247 |
|
| 248 |
# Compile
|
| 249 |
app = workflow.compile()
|
|
|
|
| 95 |
def route_continue_retrieve_documents(state):
|
| 96 |
index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
|
| 97 |
questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
|
| 98 |
+
# if questions_ipx_finished and state["search_only"]:
|
| 99 |
+
# return END
|
| 100 |
+
if questions_ipx_finished:
|
| 101 |
+
return "end_retrieve_IPx_documents"
|
| 102 |
else:
|
| 103 |
return "retrieve_documents"
|
| 104 |
|
|
|
|
| 113 |
def route_continue_retrieve_local_documents(state):
|
| 114 |
index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
|
| 115 |
questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
|
| 116 |
+
# if questions_poc_finished and state["search_only"]:
|
| 117 |
+
# return END
|
| 118 |
+
if questions_poc_finished or ("POC region" not in state["relevant_content_sources_selection"]):
|
| 119 |
+
return "end_retrieve_local_documents"
|
| 120 |
else:
|
| 121 |
return "retrieve_local_data"
|
| 122 |
|
|
|
|
| 139 |
|
| 140 |
if "Graphs (OurWorldInData)" in state["relevant_content_sources_selection"] :
|
| 141 |
sources_to_retrieve.append("retrieve_graphs")
|
| 142 |
+
|
|
|
|
| 143 |
if sources_to_retrieve == []:
|
| 144 |
return END
|
| 145 |
return sources_to_retrieve
|
|
|
|
| 159 |
answer_ai_impact = make_ai_impact_node(llm)
|
| 160 |
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
| 161 |
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
| 162 |
+
# retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
| 163 |
answer_rag = make_rag_node(llm, with_docs=True)
|
| 164 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
| 165 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
|
|
|
| 174 |
workflow.add_node("answer_chitchat", answer_chitchat)
|
| 175 |
workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
|
| 176 |
workflow.add_node("retrieve_graphs", retrieve_graphs)
|
| 177 |
+
# workflow.add_node("retrieve_local_data", retrieve_local_data)
|
| 178 |
workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
|
| 179 |
workflow.add_node("retrieve_documents", retrieve_documents)
|
| 180 |
workflow.add_node("answer_rag", answer_rag)
|
|
|
|
| 201 |
route_translation,
|
| 202 |
make_id_dict(["translate_query","transform_query"])
|
| 203 |
)
|
| 204 |
+
|
| 205 |
+
workflow.add_conditional_edges(
|
| 206 |
+
"answer_search",
|
| 207 |
+
lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
|
| 208 |
+
make_id_dict(["answer_rag","answer_rag_no_docs"])
|
| 209 |
+
)
|
| 210 |
+
workflow.add_conditional_edges(
|
| 211 |
+
"transform_query",
|
| 212 |
+
route_retrieve_documents,
|
| 213 |
+
make_id_dict(["retrieve_graphs", END])
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Define the edges
|
| 217 |
+
workflow.add_edge("translate_query", "transform_query")
|
| 218 |
+
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
| 219 |
+
# workflow.add_edge("transform_query", "retrieve_local_data")
|
| 220 |
+
# workflow.add_edge("transform_query", END) # TODO remove
|
| 221 |
+
|
| 222 |
+
workflow.add_edge("retrieve_graphs", END)
|
| 223 |
+
workflow.add_edge("answer_rag", END)
|
| 224 |
+
workflow.add_edge("answer_rag_no_docs", END)
|
| 225 |
+
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
| 226 |
+
workflow.add_edge("retrieve_graphs_chitchat", END)
|
| 227 |
+
|
| 228 |
+
# workflow.add_edge("retrieve_local_data", "answer_search")
|
| 229 |
+
workflow.add_edge("retrieve_documents", "answer_search")
|
| 230 |
+
|
| 231 |
+
# Compile
|
| 232 |
+
app = workflow.compile()
|
| 233 |
+
return app
|
| 234 |
+
|
| 235 |
+
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
|
| 236 |
+
|
| 237 |
+
workflow = StateGraph(GraphState)
|
| 238 |
+
|
| 239 |
+
# Define the node functions
|
| 240 |
+
categorize_intent = make_intent_categorization_node(llm)
|
| 241 |
+
transform_query = make_query_transform_node(llm)
|
| 242 |
+
translate_query = make_translation_node(llm)
|
| 243 |
+
answer_chitchat = make_chitchat_node(llm)
|
| 244 |
+
answer_ai_impact = make_ai_impact_node(llm)
|
| 245 |
+
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
| 246 |
+
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
| 247 |
+
retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
| 248 |
+
answer_rag = make_rag_node(llm, with_docs=True)
|
| 249 |
+
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
| 250 |
+
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
| 251 |
+
|
| 252 |
+
# Define the nodes
|
| 253 |
+
# workflow.add_node("set_defaults", set_defaults)
|
| 254 |
+
workflow.add_node("categorize_intent", categorize_intent)
|
| 255 |
+
workflow.add_node("answer_climate", dummy)
|
| 256 |
+
workflow.add_node("answer_search", answer_search)
|
| 257 |
+
# workflow.add_node("end_retrieve_local_documents", dummy)
|
| 258 |
+
# workflow.add_node("end_retrieve_IPx_documents", dummy)
|
| 259 |
+
workflow.add_node("transform_query", transform_query)
|
| 260 |
+
workflow.add_node("translate_query", translate_query)
|
| 261 |
+
workflow.add_node("answer_chitchat", answer_chitchat)
|
| 262 |
+
workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
|
| 263 |
+
workflow.add_node("retrieve_graphs", retrieve_graphs)
|
| 264 |
+
workflow.add_node("retrieve_local_data", retrieve_local_data)
|
| 265 |
+
workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
|
| 266 |
+
workflow.add_node("retrieve_documents", retrieve_documents)
|
| 267 |
+
workflow.add_node("answer_rag", answer_rag)
|
| 268 |
+
workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
|
| 269 |
+
|
| 270 |
+
# Entry point
|
| 271 |
+
workflow.set_entry_point("categorize_intent")
|
| 272 |
+
|
| 273 |
+
# CONDITIONAL EDGES
|
| 274 |
+
workflow.add_conditional_edges(
|
| 275 |
+
"categorize_intent",
|
| 276 |
+
route_intent,
|
| 277 |
+
make_id_dict(["answer_chitchat","answer_climate"])
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
workflow.add_conditional_edges(
|
| 281 |
+
"chitchat_categorize_intent",
|
| 282 |
+
chitchat_route_intent,
|
| 283 |
+
make_id_dict(["retrieve_graphs_chitchat", END])
|
|
|
|
| 284 |
)
|
| 285 |
+
|
| 286 |
workflow.add_conditional_edges(
|
| 287 |
+
"answer_climate",
|
| 288 |
+
route_translation,
|
| 289 |
+
make_id_dict(["translate_query","transform_query"])
|
|
|
|
| 290 |
)
|
| 291 |
|
| 292 |
workflow.add_conditional_edges(
|
|
|
|
| 297 |
workflow.add_conditional_edges(
|
| 298 |
"transform_query",
|
| 299 |
route_retrieve_documents,
|
| 300 |
+
make_id_dict(["retrieve_graphs", END])
|
| 301 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
# Define the edges
|
| 304 |
workflow.add_edge("translate_query", "transform_query")
|
| 305 |
workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
|
| 306 |
+
workflow.add_edge("transform_query", "retrieve_local_data")
|
| 307 |
# workflow.add_edge("transform_query", END) # TODO remove
|
| 308 |
|
| 309 |
workflow.add_edge("retrieve_graphs", END)
|
|
|
|
| 311 |
workflow.add_edge("answer_rag_no_docs", END)
|
| 312 |
workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
|
| 313 |
workflow.add_edge("retrieve_graphs_chitchat", END)
|
| 314 |
+
|
| 315 |
+
workflow.add_edge("retrieve_local_data", "answer_search")
|
| 316 |
+
workflow.add_edge("retrieve_documents", "answer_search")
|
| 317 |
|
| 318 |
# Compile
|
| 319 |
app = workflow.compile()
|
climateqa/handle_stream_events.py
CHANGED
|
@@ -22,7 +22,7 @@ def convert_to_docs_to_html(docs: list[dict]) -> str:
|
|
| 22 |
docs_html.append(make_html_source(d, i))
|
| 23 |
return "".join(docs_html)
|
| 24 |
|
| 25 |
-
def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str]) -> tuple[str, list[ChatMessage], list[str]]:
|
| 26 |
"""
|
| 27 |
Handles the retrieved documents and returns the HTML representation of the documents
|
| 28 |
|
|
@@ -35,7 +35,7 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
|
|
| 35 |
tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
|
| 36 |
"""
|
| 37 |
if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
|
| 38 |
-
return history, used_documents
|
| 39 |
|
| 40 |
try:
|
| 41 |
docs = event["data"]["output"]["documents"]
|
|
@@ -49,7 +49,7 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
|
|
| 49 |
except Exception as e:
|
| 50 |
print(f"Error getting documents: {e}")
|
| 51 |
print(event)
|
| 52 |
-
return history, used_documents
|
| 53 |
|
| 54 |
def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
|
| 55 |
"""
|
|
|
|
| 22 |
docs_html.append(make_html_source(d, i))
|
| 23 |
return "".join(docs_html)
|
| 24 |
|
| 25 |
+
def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage], used_documents : list[str],related_content:list[str]) -> tuple[str, list[ChatMessage], list[str]]:
|
| 26 |
"""
|
| 27 |
Handles the retrieved documents and returns the HTML representation of the documents
|
| 28 |
|
|
|
|
| 35 |
tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
|
| 36 |
"""
|
| 37 |
if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
|
| 38 |
+
return history, used_documents, related_content
|
| 39 |
|
| 40 |
try:
|
| 41 |
docs = event["data"]["output"]["documents"]
|
|
|
|
| 49 |
except Exception as e:
|
| 50 |
print(f"Error getting documents: {e}")
|
| 51 |
print(event)
|
| 52 |
+
return history, used_documents, related_content
|
| 53 |
|
| 54 |
def stream_answer(history: list[ChatMessage], event : StreamEvent, start_streaming : bool, answer_message_content : str)-> tuple[list[ChatMessage], bool, str]:
|
| 55 |
"""
|
front/tabs/chat_interface.py
CHANGED
|
@@ -44,7 +44,7 @@ def create_chat_interface():
|
|
| 44 |
scale=12,
|
| 45 |
lines=1,
|
| 46 |
interactive=True,
|
| 47 |
-
elem_id="input-textbox"
|
| 48 |
)
|
| 49 |
|
| 50 |
config_button = gr.Button("", elem_id="config-button")
|
|
|
|
| 44 |
scale=12,
|
| 45 |
lines=1,
|
| 46 |
interactive=True,
|
| 47 |
+
elem_id=f"input-textbox"
|
| 48 |
)
|
| 49 |
|
| 50 |
config_button = gr.Button("", elem_id="config-button")
|
front/tabs/tab_examples.py
CHANGED
|
@@ -3,7 +3,7 @@ from climateqa.sample_questions import QUESTIONS
|
|
| 3 |
|
| 4 |
|
| 5 |
def create_examples_tab():
|
| 6 |
-
examples_hidden = gr.Textbox(visible=False)
|
| 7 |
first_key = list(QUESTIONS.keys())[0]
|
| 8 |
dropdown_samples = gr.Dropdown(
|
| 9 |
choices=QUESTIONS.keys(),
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
def create_examples_tab():
|
| 6 |
+
examples_hidden = gr.Textbox(visible=False, elem_id=f"examples-hidden")
|
| 7 |
first_key = list(QUESTIONS.keys())[0]
|
| 8 |
dropdown_samples = gr.Dropdown(
|
| 9 |
choices=QUESTIONS.keys(),
|
requirements.txt
CHANGED
|
@@ -4,7 +4,7 @@ azure-storage-blob
|
|
| 4 |
python-dotenv==1.0.0
|
| 5 |
langchain==0.2.1
|
| 6 |
langchain_openai==0.1.7
|
| 7 |
-
langgraph==0.
|
| 8 |
pinecone-client==4.1.0
|
| 9 |
sentence-transformers==2.6.0
|
| 10 |
huggingface-hub
|
|
|
|
| 4 |
python-dotenv==1.0.0
|
| 5 |
langchain==0.2.1
|
| 6 |
langchain_openai==0.1.7
|
| 7 |
+
langgraph==0.2.70
|
| 8 |
pinecone-client==4.1.0
|
| 9 |
sentence-transformers==2.6.0
|
| 10 |
huggingface-hub
|