Multiple sources OK
climateqa/engine/chains/answer_rag.py (CHANGED)

```diff
@@ -11,7 +11,7 @@ import time
 from ..utils import rename_chain, pass_values
 
 
-DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="Source : {source} - {page_content}")
 
 def _combine_documents(
     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
```
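For context on this change: each retrieved passage is now rendered with its source name, so the downstream answer prompt can attribute citations per report. A minimal standalone sketch of the effect, with plain dicts standing in for LangChain `Document` objects and hypothetical sample passages:

```python
# Sketch only: shows what the new DEFAULT_DOCUMENT_PROMPT produces per passage.
# Assumes langchain-core is installed; the docs below are made-up stand-ins.
from langchain_core.prompts import PromptTemplate

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("Source : {source} - {page_content}")

docs = [
    {"source": "IPCC", "page_content": "Warming of the climate system is unequivocal."},
    {"source": "PCAET", "page_content": "Paris vise la neutralité carbone en 2050."},
]

# Mirrors what _combine_documents does with sep="\n\n".
combined = "\n\n".join(DEFAULT_DOCUMENT_PROMPT.format(**doc) for doc in docs)
print(combined)
# Source : IPCC - Warming of the climate system is unequivocal.
#
# Source : PCAET - Paris vise la neutralité carbone en 2050.
```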
climateqa/engine/chains/prompts.py (CHANGED)

```diff
@@ -36,6 +36,30 @@ You are given a question and extracted passages of the IPCC and/or IPBES reports
 """
 
 
+# answer_prompt_template_old = """
+# You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
+
+# Guidelines:
+# - If the passages have useful facts or numbers, use them in your answer.
+# - When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
+# - Do not use the sentence 'Doc i says ...' to say where information came from.
+# - If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
+# - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
+# - If it makes sense, use bullet points and lists to make your answers easier to understand.
+# - You do not need to use every passage. Only use the ones that help answer the question.
+# - If the documents do not have the information needed to answer the question, just say you do not have enough information.
+# - Consider by default that the question is about the past century unless it is specified otherwise.
+# - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
+
+# -----------------------
+# Passages:
+# {context}
+
+# -----------------------
+# Question: {query} - Explained to {audience}
+# Answer in {language} with the passages citations:
+# """
+
 answer_prompt_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted passages of reports. Provide a clear and structured answer based on the passages provided, the context and the guidelines.
 
@@ -50,6 +74,8 @@ Guidelines:
 - If the documents do not have the information needed to answer the question, just say you do not have enough information.
 - Consider by default that the question is about the past century unless it is specified otherwise.
 - If the passage is the caption of a picture, you can still use it as part of your answer as any other document.
+- If you receive passages from different reports, eg IPCC and PPCP, make separate paragraphs and specify the source of the information in your answer, eg "According to IPCC, ...".
+- The different sources are IPCC, IPBES, PPCP (for Plan Climat Air Energie Territorial de Paris), PBDP (for Plan Biodiversité de Paris), Acclimaterra.
 
 -----------------------
 Passages:
@@ -60,7 +86,6 @@ Question: {query} - Explained to {audience}
 Answer in {language} with the passages citations:
 """
 
-
 papers_prompt_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
 
```
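These templates are filled like any other LangChain prompt. A minimal sketch of the variable plumbing, with the template body shortened and the sample values hypothetical:

```python
# Sketch only: the real answer_prompt_template is much longer; the variable
# names ({context}, {query}, {audience}, {language}) match the template above.
from langchain_core.prompts import PromptTemplate

answer_prompt_template = """Passages:
{context}

-----------------------
Question: {query} - Explained to {audience}
Answer in {language} with the passages citations:"""

prompt = PromptTemplate.from_template(answer_prompt_template)
print(prompt.format(
    context="Source : IPCC - Warming of the climate system is unequivocal.",
    query="Is the climate warming?",
    audience="children",
    language="English",
))
```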
climateqa/engine/chains/query_transformation.py (CHANGED)

```diff
@@ -60,7 +60,8 @@ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 
 
 ROUTING_INDEX = {
-    "
+    "IPx":["IPCC", "IPBES", "IPOS"],
+    "POC": ["AcclimaTerra", "PCAET","Biodiv"],
     "OpenAlex":["OpenAlex"],
 }
 
@@ -88,6 +89,17 @@ class Location(BaseModel):
     country:str = Field(...,description="The country if directly mentioned or inferred from the location (cities, regions, addresses), ex: France, USA, ...")
     location:str = Field(...,description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
 
+class QueryTranslation(BaseModel):
+    """Translate the query into a given language"""
+
+    question : str = Field(
+        description="""
+        Translate the questions into the given language
+        If the question is already in the given language, just return the same question
+        """,
+    )
+
+
 class QueryAnalysis(BaseModel):
     """
     Analyze the user query to extract the relevant sources
@@ -98,14 +110,16 @@ class QueryAnalysis(BaseModel):
     Also provide simple keywords to feed a search engine
     """
 
-    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra"]] = Field( #,"OpenAlex"]] = Field(
+    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET","Biodiv"]] = Field( #,"OpenAlex"]] = Field(
         ...,
         description="""
        Given a user question choose which documents would be most relevant for answering their question,
        - IPCC is for questions about climate change, energy, impacts, and everything we can find in the IPCC reports
        - IPBES is for questions about biodiversity and nature
        - IPOS is for questions about the ocean and deep sea mining
-       - AcclimaTerra is for questions about any specific place in, or close to, the French region "Nouvelle-Aquitaine"
+       - AcclimaTerra is for questions about any specific place in, or close to, the French region "Nouvelle-Aquitaine"
+       - PCAET is the Plan Climat Energie Territorial for the city of Paris
+       - Biodiv is the Biodiversity plan for the city of Paris
        """,
     )
@@ -142,7 +156,25 @@ def make_query_analysis_chain(llm):
 
 
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze
+        ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
+        ("user", "input: {input}")
+    ])
+
+
+    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
+    return chain
+
+
+def make_query_translation_chain(llm):
+    """Translate the user query into the given language"""
+
+    openai_functions = [convert_to_openai_function(QueryTranslation)]
+    llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"QueryTranslation"})
+
+
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", "You are a helpful assistant, translate the question into {language}"),
         ("user", "input: {input}")
     ])
 
```
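A hypothetical invocation of the new translation chain; it assumes the module path shown above, `langchain-openai` installed, and an OpenAI API key in the environment. The printed output is illustrative:

```python
# Hypothetical usage sketch, not code from the PR.
from langchain_openai import ChatOpenAI
from climateqa.engine.chains.query_transformation import make_query_translation_chain

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)  # any function-calling chat model
chain = make_query_translation_chain(llm)

result = chain.invoke({"input": "Comment le climat change-t-il ?", "language": "English"})
print(result)  # e.g. {'question': 'How is the climate changing?'}
```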
```diff
@@ -150,6 +182,16 @@ def make_query_analysis_chain(llm):
     chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
     return chain
 
+def group_by_sources_types(sources):
+    sources_types = {}
+    IPx_sources = ["IPCC", "IPBES", "IPOS"]
+    local_sources = ["AcclimaTerra", "PCAET","Biodiv"]
+    if any(source in IPx_sources for source in sources):
+        sources_types["IPx"] = list(set(sources).intersection(IPx_sources))
+    if any(source in local_sources for source in sources):
+        sources_types["POC"] = list(set(sources).intersection(local_sources))
+    return sources_types
+
 
 def make_query_transform_node(llm,k_final=15):
     """
```
```diff
@@ -172,12 +214,13 @@ def make_query_transform_node(llm,k_final=15):
 
     decomposition_chain = make_query_decomposition_chain(llm)
     query_analysis_chain = make_query_analysis_chain(llm)
+    query_translation_chain = make_query_translation_chain(llm)
 
     def transform_query(state):
         print("---- Transform query ----")
 
-        auto_mode = state.get("sources_auto",
-        sources_input = state.get("sources_input", ROUTING_INDEX["
+        auto_mode = state.get("sources_auto", True)
+        sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])
 
 
         new_state = {}
@@ -186,6 +229,7 @@ def make_query_transform_node(llm,k_final=15):
         decomposition_output = decomposition_chain.invoke({"input":state["query"]})
         new_state.update(decomposition_output)
 
+
         # Query Analysis
         questions = []
         for question in new_state["questions"]:
@@ -194,16 +238,32 @@ def make_query_transform_node(llm,k_final=15):
 
             # TODO WARNING llm should always return something
             # The case when the llm does not return any sources or wrong output
-            if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBES", "IPOS"] for source in query_analysis_output["sources"]):
+            if not query_analysis_output["sources"] or not all(source in ["IPCC", "IPBES", "IPOS","AcclimaTerra", "PCAET","Biodiv"] for source in query_analysis_output["sources"]):
                 query_analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
 
-
-
+            sources_types = group_by_sources_types(query_analysis_output["sources"])
+            for source_type,sources in sources_types.items():
+                question_state = {
+                    "question":question,
+                    "sources":sources,
+                    "source_type":source_type
+                }
+
+                questions.append(question_state)
+
+        # Translate question into the document language
+        for q in questions:
+            if q["source_type"]=="IPx":
+                translation_output = query_translation_chain.invoke({"input":q["question"],"language":"English"})
+                q["question"] = translation_output["question"]
+            elif q["source_type"]=="POC":
+                translation_output = query_translation_chain.invoke({"input":q["question"],"language":"French"})
+                q["question"] = translation_output["question"]
 
         # Explode the questions into multiple questions with different sources
         new_questions = []
         for q in questions:
-            question,sources = q["question"],q["sources"]
+            question,sources,source_type = q["question"],q["sources"], q["source_type"]
 
             # If not auto mode we take the configuration
             if not auto_mode:
@@ -212,7 +272,7 @@ def make_query_transform_node(llm,k_final=15):
             for index,index_sources in ROUTING_INDEX.items():
                 selected_sources = list(set(sources).intersection(index_sources))
                 if len(selected_sources) > 0:
-                    new_questions.append({"question":question,"sources":selected_sources,"index":index})
+                    new_questions.append({"question":question,"sources":selected_sources,"index":index, "source_type":source_type})
 
         # # Add the number of questions to search
         # k_by_question = k_final // len(new_questions)
@@ -222,11 +282,16 @@ def make_query_transform_node(llm,k_final=15):
         # new_state["questions"] = new_questions
         # new_state["remaining_questions"] = new_questions
 
+        n_questions = {
+            "total":len(new_questions),
+            "IPx":len([q for q in new_questions if q["index"] == "IPx"]),
+            "POC":len([q for q in new_questions if q["index"] == "POC"]),
+        }
 
         new_state = {
            "questions_list":new_questions,
-            "n_questions":
-            "handled_questions_index":[],
+            "n_questions":n_questions,
+            "handled_questions_index":[],
        }
        return new_state
 
```
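Putting the transform_query hunks together, the node now emits a state shaped roughly like the hand-written example below (not captured output; the query, sources and counts are hypothetical):

```python
# Illustrative shape of the state returned by transform_query in auto mode for
# the query "Comment Paris s'adapte-t-il au changement climatique ?".
example_new_state = {
    "questions_list": [
        # IPx questions are translated to English...
        {"question": "How is Paris adapting to climate change?",
         "sources": ["IPCC"], "index": "IPx", "source_type": "IPx"},
        # ...while POC (local) questions are translated to / kept in French.
        {"question": "Comment Paris s'adapte-t-il au changement climatique ?",
         "sources": ["PCAET"], "index": "POC", "source_type": "POC"},
    ],
    "n_questions": {"total": 2, "IPx": 1, "POC": 1},
    "handled_questions_index": [],  # filled in as retriever nodes consume questions
}
```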
climateqa/engine/chains/retrieve_documents.py (CHANGED)

```diff
@@ -290,7 +290,7 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
     Returns:
         dict: The updated state containing the retrieved and reranked documents, related content, and remaining questions.
     """
-
+    # TODO split the questions by source type in the question state + conditions on the number of questions handled per source type
     docs = state.get("documents", [])
     related_content = state.get("related_content", [])
 
@@ -304,26 +304,30 @@ async def retrieve_documents(state,config, source_type, vectorstore,reranker,llm
     # remaining_questions = state["remaining_questions"][1:]
 
     current_question_id = None
-    print("
+    print("Question indexes", list(range(len(state["questions_list"]))), "- Handled questions : ", state["handled_questions_index"])
 
     for i in range(len(state["questions_list"])):
-        if i not in state["handled_questions_index"]:
+        current_question = state["questions_list"][i]
+
+        if i not in state["handled_questions_index"] and current_question["source_type"] == source_type:
             current_question_id = i
             break
-
+
     # TODO filter on source_type
 
-    k_by_question = k_final // state["n_questions"]
-    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
-    k_images_by_question = _get_k_images_by_question(state["n_questions"])
+    k_by_question = k_final // state["n_questions"]["total"]
+    k_summary_by_question = _get_k_summary_by_question(state["n_questions"]["total"])
+    k_images_by_question = _get_k_images_by_question(state["n_questions"]["total"])
 
     sources = current_question["sources"]
     question = current_question["question"]
     index = current_question["index"]
+    source_type = current_question["source_type"]
 
     print(f"Retrieve documents for question: {question}")
     await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
 
+    print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
 
     # if index == "Vector": # always true for now #TODO rename to IPx
     if source_type == "IPx": # always true for now #TODO rename to IPx
@@ -393,7 +397,7 @@ def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
     @chain
     async def retrieve_IPx_docs(state, config):
         source_type = "IPx"
-        return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
+        # return {"documents":[], "related_contents": [], "handled_questions_index": list(range(len(state["questions_list"])))} # TODO Remove
 
         state = await retrieve_documents(
             state = state,
```
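The selection loop now skips questions that belong to the other retriever, so the IPx and POC nodes can share one questions_list. A minimal standalone restatement of the filter with a toy state:

```python
# Toy restatement of the new question-selection filter from the hunk above.
state = {
    "questions_list": [
        {"question": "q1", "source_type": "IPx"},
        {"question": "q2", "source_type": "POC"},
    ],
    "handled_questions_index": [0],  # q1 was already handled
}
source_type = "POC"  # the retriever node asking for work

current_question_id = None
for i in range(len(state["questions_list"])):
    current_question = state["questions_list"][i]
    if i not in state["handled_questions_index"] and current_question["source_type"] == source_type:
        current_question_id = i
        break

print(current_question_id)  # 1 -> the POC retriever picks up q2
```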
climateqa/engine/graph.py (CHANGED)

```diff
@@ -93,21 +93,40 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
         return "answer_rag_no_docs"
 
 def route_continue_retrieve_documents(state):
-    if len(state["remaining_questions"]) == 0 and state["search_only"] :
+    index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
+    questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
+    if questions_ipx_finished and state["search_only"]:
         return END
-    elif len(state["remaining_questions"]) > 0:
-        return "answer_search"
-    else :
+    elif questions_ipx_finished:
+        return "answer_search"
+    else:
         return "retrieve_documents"
+
+
+    # if state["n_questions"]["IPx"] == len(state["handled_questions_index"]) and state["search_only"] :
+    #     return END
+    # elif state["n_questions"]["IPx"] == len(state["handled_questions_index"]):
+    #     return "answer_search"
+    # else :
+    #     return "retrieve_documents"
 
 def route_continue_retrieve_local_documents(state):
-    if len(state["remaining_questions"]) == 0 and state["search_only"] :
+    index_question_poc = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "POC"]
+    questions_poc_finished = all(elem in state["handled_questions_index"] for elem in index_question_poc)
+    if questions_poc_finished and state["search_only"]:
         return END
-    elif len(state["remaining_questions"]) > 0:
+    elif questions_poc_finished:
         return "answer_search"
-    else :
+    else:
         return "retrieve_local_data"
 
+    # if state["n_questions"]["POC"] == len(state["handled_questions_index"]) and state["search_only"] :
+    #     return END
+    # elif state["n_questions"]["POC"] == len(state["handled_questions_index"]):
+    #     return "answer_search"
+    # else :
+    #     return "retrieve_local_data"
+
     # if len(state["remaining_questions"]) == 0 and state["search_only"] :
     #     return END
     # elif len(state["remaining_questions"]) > 0:
```
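For intuition on the new routing: the graph keeps looping through retrieve_documents until every IPx question is in handled_questions_index, then moves on. A self-contained sketch with a toy state, where the string "__end__" stands in for langgraph's END sentinel:

```python
# Toy restatement of route_continue_retrieve_documents from the hunk above.
END = "__end__"  # stand-in for langgraph's END sentinel

def route_continue_retrieve_documents(state):
    index_question_ipx = [i for i, x in enumerate(state["questions_list"]) if x["source_type"] == "IPx"]
    questions_ipx_finished = all(elem in state["handled_questions_index"] for elem in index_question_ipx)
    if questions_ipx_finished and state["search_only"]:
        return END
    elif questions_ipx_finished:
        return "answer_search"
    else:
        return "retrieve_documents"

state = {
    "questions_list": [{"source_type": "IPx"}, {"source_type": "POC"}],
    "handled_questions_index": [0],  # the only IPx question (index 0) is done
    "search_only": False,
}
print(route_continue_retrieve_documents(state))  # answer_search
```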
```diff
@@ -216,8 +235,8 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
 
     # Define the edges
     workflow.add_edge("translate_query", "transform_query")
-    # workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
-    workflow.add_edge("transform_query", END) # TODO remove
+    workflow.add_edge("transform_query", "retrieve_documents") #TODO put back
+    # workflow.add_edge("transform_query", END) # TODO remove
 
     workflow.add_edge("retrieve_graphs", END)
     workflow.add_edge("answer_rag", END)
```
climateqa/event_handler.py (CHANGED)

```diff
@@ -35,7 +35,7 @@ def handle_retrieved_documents(event: StreamEvent, history : list[ChatMessage],
         tuple[str, list[ChatMessage], list[str]]: The updated HTML representation of the documents, the updated message history and the updated list of used documents
     """
     if "documents" not in event["data"]["output"] or event["data"]["output"]["documents"] == []:
-            return history, used_documents
+        return history, used_documents
 
     try:
         docs = event["data"]["output"]["documents"]
```
|