diff --git a/.gitignore b/.gitignore index d2c079820a..90618d8e52 100644 --- a/.gitignore +++ b/.gitignore @@ -78,9 +78,30 @@ download_models_hf.py memory-bank .claude/CLAUDE.md .claude/ - -# Claude Code CLAUDE.md -# Google Jules -.jules/ + +# My changes to the gitignore file +# 1. Ignore everything in the repository +/* + +# 2. Allow the 'jrs' folder itself +!/jrs/ + +# 3. Ignore everything inside 'jrs' (including the 'work' folder) +/jrs/* + +# 4. Allow specific files directly inside 'jrs' +!/jrs/*.* + +# 5. Allow the _notes folder specifically +!/jrs/_notes/ + +# 6. Allow everything inside the _notes folder +!/jrs/_notes/** + +# 7. Allow the archive folder specifically +!/jrs/archive/ + +# 8. Allow everything inside the _notes folder +!/jrs/archive/** diff --git a/jrs/_0_interesting_queries.md b/jrs/_0_interesting_queries.md new file mode 100644 index 0000000000..2c7702e02f --- /dev/null +++ b/jrs/_0_interesting_queries.md @@ -0,0 +1,144 @@ +## Ways To Query The LighRAG Server + +## Query 1 + +### LightRAG server will produce links which will bring you to the exact spot in a video where the query is addressed. + +The magic which make this possible is in the way the data was prepared and indexed. +[This library](https://github.com/johnshearing/scrape_yt_mk_transcripts) was used to prepare the data for indexing. +
+How the query is written also matters. The following is an example. + +**Paste the following query into the LightRAG server prompt box. +Try Naive, Local, Global, Hybrid, and Mix mode for each query** + + +[The source text is a transcript from the source video. The URL for the source video is found in the metadata associated with the source text. Time stamps in the source text appear between brackets at the beginning of each line as shown in the following sample text: [6.56 > 11.4] (Roger) the different years that in those that had COVID there were higher levels of oxidative stress. In the sample text, the values given between the brackets are given in seconds and decimal fractions of a second. So the first value in the timestamp sample is six seconds and 56 hundreths of a second. Only use the integer part of the value and only return the first value in the pair. So from the sample text, only use the number 6 to represent six seconds into the video when generating your answer. To say it another way, only use the whole number part of the first value given between the brackets when forming your answer. Please provide URLs in your responses for source video that incorporate the timestamps found in the source text so that the videos can be viewed at the moment which is relevant to the query response. The following is an example of how the URLs should be structured if the source text is from Ln3WszTq0uA.txt and the moment of interest has a time stamp of [6.56 > 11.68]: https://www.youtube.com/watch?v=Ln3WszTq0uA&t=6s. To build the URL always start with the URL you find in the metadata which is associated with the source text. In this case the URL associated with the source text is "https://www.youtube.com/watch?v=Ln3WszTq0uA". Append the characters &t= to the URL. So now the URL will look as follows: "https://www.youtube.com/watch?v=Ln3WszTq0uA&t=". Next, append the number 6 to the URL if the timestamp is [6.56 > 11.4] because we are only interested in the whole number part of the first value in the timestamp. So now the URL will look as follows: "https://www.youtube.com/watch?v=Ln3WszTq0uA&t=6". Finally, append the letter s to the URL to denote seconds. So the final resulting URL will look as follows: "https://www.youtube.com/watch?v=Ln3WszTq0uA&t=6s"] Please search for all instances where infrared light is discussed with regard to studies done in Brazil. + +**The following was the response produced** + +Infrared Light Studies in Brazil +In the context of infrared light and its potential benefits for COVID-19, there is a notable study mentioned that was conducted in Brazil. This study involved the use of a jacket for patients with mild to moderate COVID-19 who were hospitalized but not on ventilators. The findings indicated that the use of near-infrared light therapy significantly improved several pulmonary functions and reduced hospitalization time by four days. Specifically, there was a 32% improvement in maximum inspiratory pressure and a 23% improvement in maximum expiratory pressure. + +The researchers highlighted the systemic effects of photobiomodulation, which suggests that near-infrared light could have broader implications for treating conditions like COVID-19. However, the study's main limitation was the small cohort size of only 30 patients, suggesting that further research with larger populations is necessary to strengthen the conclusions. + +For more information on this topic, you can view the relevant section of the video discussing the Brazilian study at this timestamp: https://www.youtube.com/watch?v=Hy8G5TKxiLY&t=2102s. + +### Summary +The discussion around topical sunscreen highlights a critical view of its use in preventing skin cancer, suggesting that blocking UV light with sunscreen may lead to adverse health outcomes, including an increased risk of melanoma. + +--- +#### The following explains the effect of the query formating in the above example + + * Query methods available are naive, local, global, hybrid, mix, and bypass. + * naive only looks in the vector database for information but will give thoughtfull answers beyond what it finds. So it's not just searching for words. Rather it is searching for meaning and reporting on that. + * local uses information found in the knowledge graph nodes. + * global uses information found in knowledge graph relations. + * hybrid combines local and global methods. + * mix combines hybrid with naive + * The results vary significantly based on the method used. + * There is no telling which will give the best result so it is good to try them all. + * The python script _2_index_query_??.py does indeed try all the methods so it might be wise to run the query in the script first to see which method produces the best response and then run the query again in the LightRAG server to get the best formatting. + +* [Time stamps in the...] + * Specify special instructions or formating instructions in brackets before the query. This greatly improves the quality of the response. + + +## Query 2 + +### Metadata: LightRAG server will list data about the source material and produce links to that source material. + +The magic which make this possible is in the way the data was prepared and indexed. +[This library](https://github.com/johnshearing/scrape_yt_mk_transcripts) was used to prepare the data for indexing. +
+How the query is written also matters. The following is an example. + +**Paste the following query into the LightRAG server prompt box.** + +/naive Please list the metadata for the source material. + +**The following is was the response produced** + +### Metadata for the Source Material +- **Video Title:** Dr Jack Kruse: Light, Melanin, POMC in Human Evolution & Disease +- **Video URL:** [Watch here](https://www.youtube.com/watch?v=Ln3WszTq0uA) +- **Video Platform:** YouTube +- **Video Channel:** Max Gulhane MD +- **Posted Date:** April 2, 2023 +- **Language:** English + +### References +- [DC] Ln3WszTq0uA_metadata.json + +## Query 3 + +### LightRAG server will produce diagrams. + +How the query is written matters. +The query type is specified first: /naive +Special instructions about the output go in between brackets [] +Finally the query is written. +The following is an example. + +**Paste the following query into the LightRAG server prompt box.** + +>/mix [Use mermaid format for diagrams] Please draw a character relationship diagram for Age of Mammals + +**The following is was the response produced** + +```mermaid +graph TD; + A[Age of Mammals] -->|Describes| B[Max] + A -->|Emergence after| C[The Last Extinction Event] + A -->|Links to| D[Darwin's Theory of Evolution] + A -->|Impacted by| E[Asteroid Event] + A -->|Related to| F[Human Evolution] + A -->|Significant for| G[Dinosaurs] +``` + +### Explanation: +- The diagram illustrates the relationships surrounding the "Age of Mammals," showing its connections to key entities such as Max, the Last Extinction Event, Darwin's Theory of Evolution, the Asteroid Event, Human Evolution, and Dinosaurs. Each relationship highlights how these entities are interconnected in the context of evolutionary history. + +## Query 4 + +### LightRAG can find orphans in the knowledge graph. + +This only works in a local or mixed mode query. +And currently this will only work for 23 items in the list when using LR-Server. +When this same query is use in a python query script many more list items can be used in the query. +I am trying to discover what is making that difference. +The query type is specified first: /local +Special instructions about the output go in between brackets [] +Finally the query is written. + +The following is an example. +**Paste the following query into the LightRAG server prompt box.** + +>/local [Only provide the name of the entity in the response. Nothing else is required.] Please examine the Entity ID for all Entities in the following python list. Then please return only the Entities with a Rank of 0. [ "2023-04-02T06:06:17Z", "2023-04-22T23:01:27Z", "Alabama", "Albury", "Biological Compartments", "Cellular Health", "Circadian Biology", "Circadian Health", "Circadian Rhythm", "Cold Thermogenesis Protocol", "Cytochrome C Oxidase", "Delta Airlines", "Deuterium Depleted Water", "Dr Jack Kruse: Light, Melanin, POMC in Human Evolution & Disease", "Dr Jack Kruse: WATER, non-native EMFs & mitochondrial basis of disease | Regenerative Health Podcast", "Dr. Anthony Chafee", "Dr. Jack Kruse", "Dr. Max Gulhane", "Eureka Moment", "Farm Tour", "Health Optimization"] + +**The following is was the response produced** + +>The only Entity from your provided list with a Rank of 0 is: +> +>Albury +>This entity corresponds to the location in New South Wales where Dr. Max Gulhane is based. + +## Query 5 + +### LightRAG can find Type, Description, Rank, and File Path in the knowledge graph. In the previous example we searched from Entity to Entity ID to Rank in order to find orphans. In the following example we go through Entity to Entity ID to Type in order to return categories. + +This only works in a mixed mode query. +The query type is specified first: /mixed +Special instructions about the output go in between brackets [] +Finally the query is written. + +The following is an example. +**Paste the following query into the LightRAG server prompt box.** + +>/mix [Only provide the name of the entity in the response. Nothing else is required.] Please examine the Entity ID for all Entities in the following python list. Then please return only the Entities with a Type of Person. [ "2023-04-02T06:06:17Z", "2023-04-22T23:01:27Z", "Alabama", "Albury", "Biological Compartments", "Cellular Health", "Circadian Biology", "Circadian Health", "Circadian Rhythm", "Cold Thermogenesis Protocol", "Cytochrome C Oxidase", "Delta Airlines", "Deuterium Depleted Water", "Dr Jack Kruse: Light, Melanin, POMC in Human Evolution & Disease", "Dr Jack Kruse: WATER, non-native EMFs & mitochondrial basis of disease | Regenerative Health Podcast", "Dr. Anthony Chafee", "Dr. Jack Kruse", "Dr. Max Gulhane", "Eureka Moment", "Farm Tour", "Health Optimization"] + +**The following is was the response produced** + +>Dr. Anthony Chafee +Dr. Jack Kruse +Dr. Max Gulhane diff --git a/jrs/_1_custom_index_01.py b/jrs/_1_custom_index_01.py new file mode 100644 index 0000000000..a2feb9b720 --- /dev/null +++ b/jrs/_1_custom_index_01.py @@ -0,0 +1,170 @@ +import os +import asyncio +import logging +import logging.config +import json +import numpy as np +from lightrag import LightRAG +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import logger, set_verbose_debug, EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# import textract + +# Configuration +WORKING_DIR = "/home/js/LightRAG/jrs/work/seheult/_seheult_work_dir" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) + +# Files to be indexed +files_2b_indexed = [ + "/home/js/LightRAG/jrs/work/seheult/seheult_metadata/_bNySyEobfY_metadata.json", + "/home/js/LightRAG/jrs/work/seheult/seheult_metadata/0m1Qekrfs7w_metadata.json", +] + + +def configure_logging(): + """Configure logging with console and rotating file handlers.""" + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_index.log")) + print(f"\nLightRAG index log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(levelname)s: %(message)s"}, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "true").lower() == "true") + + +if not os.path.exists(WORKING_DIR): + os.makedirs(WORKING_DIR) + + +async def initialize_rag(): + """Initialize LightRAG with custom embedding function.""" + print("Initializing LightRAG for indexing...") + + # Initialize embedding model + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + # Define async embedding function + async def async_embedding_func(texts): + # llama-index returns a list; we convert it to a numpy array for LightRAG + embeddings = await embed_model.aget_text_embedding_batch(texts) + return np.array(embeddings) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + + # Initialize LightRAG + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +async def main(): + """Main function to index documents.""" + rag = None + try: + if not os.getenv("OPENAI_API_KEY") and not API_KEY: + raise ValueError( + "OPENAI_API_KEY or EMBEDDING_BINDING_API_KEY environment variable not set" + ) + rag = await initialize_rag() + + # Check which files are already indexed + indexed_files = set() + doc_status_file = os.path.join(WORKING_DIR, "kv_store_doc_status.json") + if os.path.exists(doc_status_file): + with open(doc_status_file, "r") as f: + docs = json.load(f) + indexed_files = { + doc["file_path"] + for doc in docs.values() + if doc.get("status") == "processed" and "file_path" in doc + } + print(f"Already indexed files: {indexed_files}") + + # Index new documents + for doc_path in files_2b_indexed: + if doc_path in indexed_files: + print(f"Skipping already indexed file: {doc_path}") + continue + print(f"Checking document at: {doc_path}") + if not os.path.exists(doc_path): + print(f"Document file not found at: {doc_path}, skipping...") + continue + print(f"Indexing document: {doc_path}...") + + with open(doc_path, "r") as f: + docs = json.load(f) + + await rag.ainsert_custom_kg(docs, full_doc_id=os.path.basename(doc_path)) + print(f"Indexed {doc_path}") + except Exception as e: + print(f"An error occurred: {e}") + import traceback + + traceback.print_exc() + finally: + if rag: + print("Finalizing storages...") + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) + print("\nIndexing Done!") diff --git a/jrs/_1_ra_index.py b/jrs/_1_ra_index.py new file mode 100644 index 0000000000..121b8a6796 --- /dev/null +++ b/jrs/_1_ra_index.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +import os +import argparse +import asyncio + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_indexing(file_path, output_dir, api_key, base_url, working_dir): + try: + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + enable_equation_processing=True, + ) + + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + ( + {"role": "system", "content": system_prompt} + if system_prompt + else None + ), + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + }, + ], + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + allow_extra_vectors=True, # ← correct place + ) + + rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + + print(f"--- Starting Indexing: {file_path} ---") + await rag.process_document_complete( + file_path=file_path, output_dir=output_dir, parse_method="auto" + ) + + if hasattr(rag, "close"): + res = rag.close() + if asyncio.iscoroutine(res): + await res + print("--- Indexing Success ---") + + except Exception as e: + print(f"Indexing Error: {e}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-f", + "--file_path", + default="/home/js/LightRAG-Dev/jrs/work/music/music_data/lesson02.pdf", + ) + parser.add_argument( + "-w", + "--working_dir", + default="/home/js/LightRAG-Dev/jrs/work/music/music_work_dir", + ) + parser.add_argument( + "-o", + "--output", + default="/home/js/LightRAG-Dev/jrs/work/music/music_output_dir", + ) + parser.add_argument("--api-key", default=os.getenv("OPENAI_API_KEY")) + parser.add_argument("--base-url", default=os.getenv("OPENAI_BASE_URL")) + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + os.makedirs(args.working_dir, exist_ok=True) + asyncio.run( + run_indexing( + args.file_path, args.output, args.api_key, args.base_url, args.working_dir + ) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/_2_ra_query_image.py b/jrs/_2_ra_query_image.py new file mode 100644 index 0000000000..765c33a78f --- /dev/null +++ b/jrs/_2_ra_query_image.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python + +# For usage instructions enter the following command: +# python3 ~/LightRAG/jrs/_2_ra_query_image.py --help + +import os +import argparse +import asyncio + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_image_query( + query_text, api_key, base_url, working_dir, modes, output_file +): + try: + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + ) + + # 1. Text LLM + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + # 2. Vision LLM (Crucial for Image Queries) + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + ( + {"role": "system", "content": system_prompt} + if system_prompt + else None + ), + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + }, + ], + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + + print("INFO: Initializing Multimodal Engine...") + await rag._ensure_lightrag_initialized() + + # Prepare Markdown File + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"\n# Query: {query_text}\n") + f.write(f"**Working Dir:** `{working_dir}`\n\n") + + # --- MULTI-MODE QUERY LOOP --- + for current_mode in modes: + print(f"\n>>> Executing [ {current_mode.upper()} ] mode...") + + try: + # query_with_multimodal is the method for reasoning over indexed visuals + result = await rag.aquery_with_multimodal(query_text, mode=current_mode) + + # Output to Console + print(f"\n[ {current_mode.upper()} ANSWER ]:") + print(f"{result}") + + # Output to Markdown File + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"## Mode: {current_mode.upper()}\n") + f.write(f"{result}\n\n") + f.write("---\n") + + except Exception as e: + error_msg = f"Error in {current_mode} mode: {e}" + print(error_msg) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"### Mode: {current_mode.upper()} (FAILED)\n") + f.write(f"Error: {error_msg}\n\n") + + # Cleanup attempts (preserving existing logic) + if hasattr(rag, "finalize_storages"): + res = rag.finalize_storages() + if asyncio.iscoroutine(res): + await res + + if hasattr(rag, "lightrag") and rag.lightrag: + if hasattr(rag.lightrag, "storage") and hasattr( + rag.lightrag.storage, "close" + ): + await rag.lightrag.storage.close() + + del rag + + except Exception as e: + print(f"Query Error: {e}") + + +def main(): + parser = argparse.ArgumentParser(description="Multimodal Image Query Script") + parser.add_argument("query", help="Your question about the images/charts") + + # Modes parameter: Split by comma to allow multiple (e.g., -m naive,hybrid) + parser.add_argument( + "--modes", + "-m", + default="naive,local,global,hybrid,mix", + help="Comma-separated list of modes: naive,local,global,hybrid,mix", + ) + + # File parameter: Defaulting to LightRAG directory + parser.add_argument( + "--file", + "-f", + default="/home/js/LightRAG-Dev/mm_query_output.md", + help="Path to the output markdown file", + ) + + parser.add_argument( + "--working_dir", + "-w", + default="/home/js/LightRAG-Dev/jrs/work/music/music_work_dir", + help="Path to directory where index of knowledge is stored", + ) + + args = parser.parse_args() + + # Convert the comma-separated string into a clean Python list + mode_list = [m.strip().lower() for m in args.modes.split(",")] + + asyncio.run( + run_image_query( + args.query, + os.getenv("OPENAI_API_KEY"), + os.getenv("OPENAI_BASE_URL"), + args.working_dir, + mode_list, + args.file, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/_2_ra_query_text.py b/jrs/_2_ra_query_text.py new file mode 100644 index 0000000000..8b8a34a39d --- /dev/null +++ b/jrs/_2_ra_query_text.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python + +# For usage instructions enter the following command: +# python3 path_to_this_script/_2_ra_query_text.py --help + +import os +import argparse +import asyncio + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_text_query( + query_text, api_key, base_url, working_dir, modes, output_file +): + try: + # 1. Setup Config + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + enable_equation_processing=True, + ) + + # 2. Setup LLM Function (Standard text completion) + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + # 3. Setup Embedding Function (Must match the indexing phase) + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + # 4. Initialize RAGAnything + rag = RAGAnything( + config=config, llm_model_func=llm_model_func, embedding_func=embedding_func + ) + + # --- INITIALIZATION --- + print("INFO: Connecting to existing index...") + await rag._ensure_lightrag_initialized() + + if not rag.lightrag: + raise RuntimeError(f"Failed to load LightRAG from {working_dir}.") + + # Prepare Markdown File Entry + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"\n# Text Query: {query_text}\n") + f.write(f"**Working Dir:** `{working_dir}`\n\n") + + # --- MULTI-MODE QUERY LOOP --- + for current_mode in modes: + print(f"\n>>> Executing [ {current_mode.upper()} ] mode...") + + try: + # Standard text query (aquery) + result = await rag.aquery(query_text, mode=current_mode) + + # Console Output + print(f"\n[ {current_mode.upper()} ANSWER ]:") + print(f"{result}") + + # File Output + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"## Mode: {current_mode.upper()}\n") + f.write(f"{result}\n\n") + f.write("---\n") + + except Exception as e: + error_msg = f"Error in {current_mode} mode: {e}" + print(error_msg) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"### Mode: {current_mode.upper()} (FAILED)\n") + f.write(f"Error: {error_msg}\n\n") + + # --- CLEANUP --- + if hasattr(rag, "finalize_storages"): + res = rag.finalize_storages() + if asyncio.iscoroutine(res): + await res + + if hasattr(rag, "lightrag") and rag.lightrag: + if hasattr(rag.lightrag, "storage") and hasattr( + rag.lightrag.storage, "close" + ): + await rag.lightrag.storage.close() + + del rag + + except Exception as e: + print(f"Query Error: {e}") + + +def main(): + parser = argparse.ArgumentParser(description="Multi-Mode Text Query Script") + parser.add_argument("query", help="The question you want to ask about the text") + + # Modes parameter: Split by comma to allow multiple (e.g., -m naive,hybrid) + parser.add_argument( + "--modes", + "-m", + default="hybrid", + help="Comma-separated list of modes: naive,local,global,hybrid,mix", + ) + + # File parameter: Defaulting to LightRAG directory + parser.add_argument( + "--file", + "-f", + default="/home/js/LightRAG/text_query_output.md", + help="Path to the output markdown file", + ) + + parser.add_argument( + "--working_dir", + "-w", + default="/home/js/LightRAG/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_work_dir", + help="Path to directory where index of knowledge is stored", + ) + + args = parser.parse_args() + + # Process the mode string into a list + mode_list = [m.strip().lower() for m in args.modes.split(",")] + + asyncio.run( + run_text_query( + args.query, + os.getenv("OPENAI_API_KEY"), + os.getenv("OPENAI_BASE_URL"), + args.working_dir, + mode_list, + args.file, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/_3_merge_GUI_58.py b/jrs/_3_merge_GUI_58.py new file mode 100644 index 0000000000..e23e0c0d59 --- /dev/null +++ b/jrs/_3_merge_GUI_58.py @@ -0,0 +1,2039 @@ +import tkinter as tk +from tkinter import ttk, messagebox +import asyncio +import requests +import os +import json +import platform +import urllib.parse +import subprocess +from lightrag import LightRAG +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# Configuration +WORKING_DIR = "/home/js/LightRAG/jrs/work/mod_linx_text/mod_linx_work_dir" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_key") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) +LIGHTRAG_SERVER_URL = "http://localhost:9621" + + +async def initialize_rag(): + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + async def async_embedding_func(texts): + return embed_model.get_text_embedding_batch(texts) + + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +def fetch_entities(): + try: + response = requests.get(f"{LIGHTRAG_SERVER_URL}/graph/label/list") + response.raise_for_status() + entities = sorted(response.json(), key=lambda x: x.lower()) + print(f"Number of entities fetched: {len(entities)}") # Logging + return entities + except requests.exceptions.ConnectionError: + messagebox.showerror( + "Connection Error", "Could not connect to LightRAG server. Is it running?" + ) + return [] + except Exception as e: + messagebox.showerror("Error", f"Failed to fetch entities from server: {e}") + return [] + + +def fetch_entity_details(label): + try: + encoded_label = urllib.parse.quote_plus(label) + response = requests.get( + f"{LIGHTRAG_SERVER_URL}/graphs?label={encoded_label}&max_depth=1&max_nodes=20000" + ) + response.raise_for_status() + data = response.json() + + main_entity_desc = "No description found." + main_entity_type = "" + main_entity_source_id = "" + main_entity_file_path = "" + + related_nodes_info = [] + + for node in data.get("nodes", []): + node_id = node.get("id") + node_properties = node.get("properties", {}) + node_desc = node_properties.get("description", "No description found.") + node_type = node_properties.get("entity_type", "") + + if node_id == label: + main_entity_desc = node_desc + main_entity_type = node_type + main_entity_source_id = node_properties.get("source_id", "") + main_entity_file_path = node_properties.get("file_path", "") + else: + related_nodes_info.append( + {"id": node_id, "description": node_desc, "type": node_type} + ) + + edges_info = [] + for edge in data.get("edges", []): + edges_info.append( + { + "source": edge.get("source", ""), + "target": edge.get("target", ""), + "description": edge.get("properties", {}).get("description", ""), + "keywords": edge.get("properties", {}).get("keywords", ""), + "weight": edge.get("properties", {}).get("weight", 1.0), + } + ) + + return { + "desc": main_entity_desc, + "type": main_entity_type, + "srcid": main_entity_source_id, + "fpath": main_entity_file_path, + "related_nodes": related_nodes_info, + "edges": edges_info, + } + except requests.exceptions.ConnectionError: + print( + "Connection Error: Could not connect to LightRAG server to fetch entity details." + ) + return { + "desc": "Error: Server not reachable.", + "type": "", + "srcid": "", + "fpath": "", + "related_nodes": [], + "edges": [], + } + except Exception as e: + print(f"Error fetching entity details for {label}: {e}") + return { + "desc": f"Error: {e}", + "type": "", + "srcid": "", + "fpath": "", + "related_nodes": [], + "edges": [], + } + + +def trigger_server_refresh(): + try: + print("Attempting to trigger LightRAG server data refresh...") + response = requests.post(f"{LIGHTRAG_SERVER_URL}/graph/refresh-data") + response.raise_for_status() + print("LightRAG server data refresh triggered successfully.") + return True + except requests.exceptions.ConnectionError: + messagebox.showwarning( + "Server Not Running", + "Could not connect to LightRAG server to trigger refresh. Please ensure the server is running.", + ) + return False + except requests.exceptions.HTTPError as e: + messagebox.showwarning( + "API Error", + f"Failed to trigger LightRAG server data refresh: {e.response.status_code} - {e.response.text}", + ) + return False + except Exception as e: + messagebox.showwarning( + "Error", + f"An unexpected error occurred while triggering LightRAG server data refresh: {e}", + ) + return False + + +def update_entity_description_api(entity_label, new_description): + try: + url = f"{LIGHTRAG_SERVER_URL}/graph/entity/edit" + headers = {"accept": "application/json", "Content-Type": "application/json"} + payload = { + "entity_name": entity_label, + "updated_data": {"description": new_description}, + "allow_rename": False, + } + + print(f"Sending update request for {entity_label} with new description...") + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + print(f"Successfully updated description for {entity_label}.") + return True + except requests.exceptions.ConnectionError: + messagebox.showerror( + "Connection Error", + "Could not connect to LightRAG server to update description. Is it running?", + ) + return False + except requests.exceptions.HTTPError as e: + messagebox.showerror( + "API Error", + f"Failed to update description for {entity_label}: {e.response.status_code} - {e.response.text}", + ) + return False + except Exception as e: + messagebox.showerror( + "Error", + f"An unexpected error occurred while updating description for {entity_label}: {e}", + ) + return False + + +def update_relationship_api( + source_id, target_id, new_description, new_keywords, weight=1.0 +): + try: + url = f"{LIGHTRAG_SERVER_URL}/graph/relation/edit" + headers = {"accept": "application/json", "Content-Type": "application/json"} + payload = { + "source_id": source_id, + "target_id": target_id, + "updated_data": { + "description": new_description, + "keywords": new_keywords, + "weight": weight, + }, + } + + print( + f"Sending update request for relationship from {source_id} to {target_id}..." + ) + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + print(f"Successfully updated relationship from {source_id} to {target_id}.") + return True + except requests.exceptions.ConnectionError: + messagebox.showerror( + "Connection Error", + "Could not connect to LightRAG server to update relationship. Is it running?", + ) + return False + except requests.exceptions.HTTPError as e: + messagebox.showerror( + "API Error", + f"Failed to update relationship from {source_id} to {target_id}: {e.response.status_code} - {e.response.text}", + ) + return False + except Exception as e: + messagebox.showerror( + "Error", + f"An unexpected error occurred while updating relationship from {source_id} to {target_id}: {e}", + ) + return False + + +def create_relationship_api( + source_id, target_id, description, keywords, weight, source_file_id +): + try: + rag = asyncio.run(initialize_rag()) + relationship_data = { + "relationships": [ + { + "src_id": source_id, + "tgt_id": target_id, + "description": description, + "keywords": keywords, + "weight": float(weight), + "source_id": source_file_id, + } + ] + } + asyncio.run( + rag.ainsert_custom_kg( + relationship_data, full_doc_id=source_file_id, file_path=source_file_id + ) + ) + print(f"Successfully created relationship from {source_id} to {target_id}.") + return True + except requests.exceptions.ConnectionError: + messagebox.showerror( + "Connection Error", + "Could not connect to LightRAG server to create relationship. Is it running?", + ) + return False + except requests.exceptions.HTTPError as e: + messagebox.showerror( + "API Error", + f"Failed to create relationship from {source_id} to {target_id}: {e.response.status_code} - {e.response.text}", + ) + return False + except Exception as e: + messagebox.showerror( + "Error", + f"An unexpected error occurred while creating relationship from {source_id} to {target_id}: {e}", + ) + return False + finally: + if "rag" in locals(): + asyncio.run(rag.finalize_storages()) + + +class MergeGUI: + def __init__(self, root): + self.root = root + self.root.title("LightRAG Entity Merger") + self.entity_list = fetch_entities() + self.filtered_entity_list = self.entity_list.copy() + self.current_page = 0 # Initialize page counter + self.page_size = 35 # Number of entities per page + self.all_check_vars = {entity: tk.BooleanVar() for entity in self.entity_list} + self.description_windows = {} + self.description_frames = {} + self.entity_data = {} + self.config_file = "merge_gui_config.json" + + self.load_window_config() + self.setup_main_window() + + self.check_vars = {} + self.first_entity_var = tk.StringVar() + + self.create_ui() + + self.root.protocol("WM_DELETE_WINDOW", self.on_closing) + + def copy_selected_text(self, widget): + try: + if isinstance(widget, tk.Text): + # For Text widgets + selected_text = widget.get("sel.first", "sel.last") + elif isinstance(widget, ttk.Combobox): + # For Combobox widgets + if widget.selection_present(): # Check if there's a selection + selected_text = widget.selection_get() + else: + selected_text = ( + widget.get() + ) # Fallback to entire text if no selection + else: + raise ValueError("Unsupported widget type") + + self.root.clipboard_clear() + self.root.clipboard_append(selected_text) + self.root.update() + subprocess.run( + ["xclip", "-selection", "clipboard"], + input=selected_text.encode(), + check=True, + ) + print(f"Copied to clipboard: {selected_text}") + except tk.TclError: + print("No text selected to copy.") + except subprocess.CalledProcessError as e: + print(f"Failed to copy to clipboard via xclip: {e}") + except ValueError as e: + print(f"Error: {e}") + + def load_window_config(self): + self.window_config = { + "geometry": "1200x800", + "state": "normal", + "paned_position": 300, + } + + if os.path.exists(self.config_file): + try: + with open(self.config_file, "r") as f: + saved_config = json.load(f) + self.window_config.update(saved_config) + except Exception: + pass + + def set_initial_paned_position(self): + try: + self.root.update_idletasks() + window_width = self.paned_window.winfo_width() + + if window_width > 100: + position = int(window_width * 0.25) + else: + position = 300 + + self.paned_window.sashpos(0, position) + except Exception: + self.paned_window.sashpos(0, 300) + + def save_window_config(self): + try: + geometry = self.root.geometry() + state = self.root.state() + paned_position = self.paned_window.sash_coord(0)[0] + + config = { + "geometry": geometry, + "state": state, + "paned_position": paned_position, + } + + with open(self.config_file, "w") as f: + json.dump(config, f) + except Exception: + pass + + def setup_main_window(self): + self.root.geometry(self.window_config["geometry"]) + + if platform.system() == "Windows": + try: + self.root.state("zoomed") + except tk.TclError: + self.root.geometry( + f"{self.root.winfo_screenwidth()}x{self.root.winfo_screenheight()}+0+0" + ) + elif platform.system() == "Linux": + self.root.geometry( + f"{self.root.winfo_screenwidth()}x{self.root.winfo_screenheight()}+0+0" + ) + else: + self.root.geometry( + f"{self.root.winfo_screenwidth()}x{self.root.winfo_screenheight()}+0+0" + ) + + self.root.grid_rowconfigure(0, weight=1) + self.root.grid_columnconfigure(0, weight=1) + + def create_ui(self): + self.paned_window = ttk.PanedWindow(self.root, orient="horizontal") + self.paned_window.grid(row=0, column=0, sticky="nsew") + + self.left_panel = ttk.Frame(self.paned_window) + self.paned_window.add(self.left_panel, weight=1) + + self.right_panel = ttk.Frame(self.paned_window) + self.paned_window.add(self.right_panel, weight=3) + + self.root.after(10, self.set_initial_paned_position) + + self.create_right_panel() + self.create_left_panel() + + def create_left_panel(self): + self.left_panel.grid_rowconfigure(2, weight=1) + self.left_panel.grid_columnconfigure(0, weight=1) + + top_controls_frame = ttk.Frame(self.left_panel) + top_controls_frame.grid(row=0, column=0, sticky="ew", padx=5, pady=5) + top_controls_frame.grid_columnconfigure(1, weight=1) + + ttk.Label(top_controls_frame, text="Filter:").grid( + row=0, column=0, padx=(0, 5), sticky="w" + ) + self.filter_var = tk.StringVar() + self.filter_entry = ttk.Entry(top_controls_frame, textvariable=self.filter_var) + self.filter_entry.grid(row=0, column=1, sticky="ew", padx=(0, 5)) + self.filter_var.trace("w", self.on_filter_change) + + # Pagination controls + pagination_frame = ttk.Frame(self.left_panel) + pagination_frame.grid(row=1, column=0, sticky="ew", padx=5, pady=5) + pagination_frame.grid_columnconfigure(2, weight=1) + + ttk.Button( + pagination_frame, text="First", command=lambda: self.change_page_to(0) + ).grid(row=0, column=0, padx=2) + ttk.Button( + pagination_frame, text="Previous", command=lambda: self.change_page(-1) + ).grid(row=0, column=1, padx=2) + self.page_label = ttk.Label(pagination_frame, text="Page 1") + self.page_label.grid(row=0, column=2, sticky="w") + ttk.Button( + pagination_frame, text="Next", command=lambda: self.change_page(1) + ).grid(row=0, column=3, padx=2) + ttk.Button(pagination_frame, text="Last", command=self.go_to_last_page).grid( + row=0, column=4, padx=2 + ) + + ttk.Label(pagination_frame, text="Go to page:").grid( + row=1, column=1, padx=(10, 2) + ) + self.page_entry = ttk.Entry(pagination_frame, width=5) + self.page_entry.grid(row=1, column=2, padx=2) + ttk.Button(pagination_frame, text="Go", command=self.jump_to_page).grid( + row=1, column=3, padx=2 + ) + + header_frame = ttk.Frame(pagination_frame) + header_frame.grid(row=2, column=0, columnspan=8, sticky="ew", padx=5, pady=5) + header_frame.grid_columnconfigure(0, weight=0) + header_frame.grid_columnconfigure(1, weight=0) + header_frame.grid_columnconfigure(2, weight=1) + + self.desc_header_label = ttk.Label( + header_frame, + text="Show\nDesc", + font=("TkDefaultFont", 9, "bold"), + anchor="center", + justify="center", + ) + self.desc_header_label.grid(row=0, column=0, padx=5, sticky="ew") + + self.keep_first_label = ttk.Label( + header_frame, + text="Keep\nFirst", + font=("TkDefaultFont", 9, "bold"), + anchor="center", + justify="center", + ) + self.keep_first_label.grid(row=0, column=1, padx=5, sticky="ew") + + self.select_entities_label = ttk.Label( + header_frame, + text="Select\nEntities", + font=("TkDefaultFont", 9, "bold"), + anchor="w", + justify="left", + ) + self.select_entities_label.grid(row=0, column=2, padx=5, sticky="w") + + action_buttons_frame = ttk.Frame(top_controls_frame) + action_buttons_frame.grid( + row=1, column=0, columnspan=2, sticky="ew", pady=(5, 0) + ) + action_buttons_frame.grid_columnconfigure(0, weight=1) + action_buttons_frame.grid_columnconfigure(1, weight=1) + action_buttons_frame.grid_columnconfigure(2, weight=1) + action_buttons_frame.grid_rowconfigure(0, weight=1) + action_buttons_frame.grid_rowconfigure(1, weight=1) + + self.select_all_of_entity_button = ttk.Button( + action_buttons_frame, + text="Select All Of Type", + command=self.select_all_of_entity_type, + ) + self.select_all_of_entity_button.grid(row=0, column=0, sticky="ew", padx=(0, 2)) + + self.select_all_orphans_button = ttk.Button( + action_buttons_frame, text="Select Orphans", command=self.select_all_orphans + ) + self.select_all_orphans_button.grid(row=0, column=1, sticky="ew", padx=(2, 2)) + + self.clear_selected_button = ttk.Button( + action_buttons_frame, + text="Clear Selected", + command=self.clear_selected_entities, + ) + self.clear_selected_button.grid(row=0, column=2, sticky="ew", padx=(2, 2)) + + self.show_selected_button = ttk.Button( + action_buttons_frame, text="Show Selected", command=self.show_selected_only + ) + self.show_selected_button.grid(row=1, column=0, sticky="ew", padx=(0, 2)) + + self.clear_filter_button = ttk.Button( + action_buttons_frame, text="Show All", command=self.clear_filter + ) + self.clear_filter_button.grid(row=1, column=1, sticky="ew", padx=(2, 2)) + + self.clear_all_button = ttk.Button( + action_buttons_frame, text="Reset All", command=self.clear_all_selections + ) + self.clear_all_button.grid(row=1, column=2, sticky="ew", padx=(2, 0)) + + content_frame = ttk.Frame(self.left_panel) + content_frame.grid(row=2, column=0, sticky="nsew", padx=5, pady=5) + content_frame.grid_rowconfigure(0, weight=1) + content_frame.grid_columnconfigure(1, weight=1) + + self.toggle_desc_button = ttk.Button( + content_frame, text="⇕", command=self.toggle_descriptions, width=3 + ) + self.toggle_desc_button.grid(row=0, column=0, sticky="ns", padx=(0, 5)) + + list_frame = ttk.Frame(content_frame) + list_frame.grid(row=0, column=1, sticky="nsew") + list_frame.grid_rowconfigure(0, weight=1) + list_frame.grid_columnconfigure(0, weight=1) + + self.canvas = tk.Canvas(list_frame) + self.scrollbar = ttk.Scrollbar( + list_frame, orient="vertical", command=self.canvas.yview + ) + self.scrollable_frame = ttk.Frame(self.canvas) + + self.scrollable_frame.bind( + "", + lambda e: self.canvas.configure(scrollregion=self.canvas.bbox("all")), + ) + self.canvas.create_window((0, 0), window=self.scrollable_frame, anchor="nw") + self.canvas.configure(yscrollcommand=self.scrollbar.set) + + self.canvas.grid(row=0, column=0, sticky="nsew") + self.scrollbar.grid(row=0, column=1, sticky="ns") + + self.scrollable_frame.grid_columnconfigure(0, weight=0) + self.scrollable_frame.grid_columnconfigure(1, weight=1) + + def _on_mousewheel(event): + try: + if platform.system() == "Windows": + if hasattr(event, "delta") and event.delta: + self.canvas.yview_scroll(int(-1 * (event.delta / 120)), "units") + elif platform.system() == "Darwin": + if hasattr(event, "delta") and event.delta: + self.canvas.yview_scroll(int(-1 * event.delta), "units") + else: + if event.num == 4: + self.canvas.yview_scroll(-1, "units") + elif event.num == 5: + self.canvas.yview_scroll(1, "units") + elif hasattr(event, "delta") and event.delta: + self.canvas.yview_scroll(int(-1 * (event.delta / 120)), "units") + except Exception: + pass + return "break" + + scroll_events = [] + if platform.system() == "Windows": + scroll_events = ["", ""] + elif platform.system() == "Darwin": + scroll_events = ["", "", ""] + else: + scroll_events = ["", "", ""] + + for event in scroll_events: + try: + self.canvas.bind(event, _on_mousewheel) + except Exception: + pass + + def on_canvas_enter(event): + self.canvas.focus_set() + + def on_canvas_leave(event): + self.root.focus_set() + + self.canvas.bind("", on_canvas_enter) + self.canvas.bind("", on_canvas_leave) + self.canvas.config(takefocus=True) + + def _on_key(event): + if event.keysym == "Up": + self.canvas.yview_scroll(-1, "units") + return "break" + elif event.keysym == "Down": + self.canvas.yview_scroll(1, "units") + return "break" + elif event.keysym == "Page_Up": + self.canvas.yview_scroll(-5, "units") + return "break" + elif event.keysym == "Page_Down": + self.canvas.yview_scroll(5, "units") + return "break" + + self.canvas.bind("", _on_key) + self.canvas.bind("", lambda e: e.widget.focus_set()) + + self.create_entity_list() + + def create_entity_list(self): + for widget in self.scrollable_frame.winfo_children(): + widget.destroy() + + self.check_vars.clear() + + # Calculate pagination + start_idx = self.current_page * self.page_size + end_idx = min(start_idx + self.page_size, len(self.filtered_entity_list)) + paginated_entities = self.filtered_entity_list[start_idx:end_idx] + + # Update page label + total_pages = ( + len(self.filtered_entity_list) + self.page_size - 1 + ) // self.page_size + self.page_label.config(text=f"Pg {self.current_page + 1} of {total_pages}") + + for i, ent in enumerate(paginated_entities): + rb = ttk.Radiobutton( + self.scrollable_frame, + text="", + variable=self.first_entity_var, + value=ent, + ) + rb.grid(row=i, column=0, padx=(5, 2), pady=1, sticky="w") + + var = self.all_check_vars.get(ent) + if var is None: + var = tk.BooleanVar() + self.all_check_vars[ent] = var + + cb = ttk.Checkbutton( + self.scrollable_frame, + text=ent, + variable=var, + command=self.update_selection, + ) + cb.grid(row=i, column=1, padx=(2, 5), pady=1, sticky="w") + + current_first_entity = self.first_entity_var.get() + if current_first_entity and current_first_entity in self.filtered_entity_list: + self.first_entity_var.set(current_first_entity) + else: + self.first_entity_var.set("") + + self.update_selection() + + def change_page(self, direction): + total_pages = ( + len(self.filtered_entity_list) + self.page_size - 1 + ) // self.page_size + self.current_page = self.current_page + direction + if self.current_page < 0: + self.current_page = 0 + elif self.current_page >= total_pages: + self.current_page = total_pages - 1 + self.create_entity_list() + + def change_page_to(self, page): + total_pages = ( + len(self.filtered_entity_list) + self.page_size - 1 + ) // self.page_size + self.current_page = max(0, min(page, total_pages - 1)) + self.create_entity_list() + + def go_to_last_page(self): + total_pages = ( + len(self.filtered_entity_list) + self.page_size - 1 + ) // self.page_size + self.current_page = total_pages - 1 + self.create_entity_list() + + def jump_to_page(self): + try: + page_num = int(self.page_entry.get()) - 1 + self.change_page_to(page_num) + except ValueError: + messagebox.showerror("Invalid Input", "Please enter a valid page number.") + + def clear_selected_entities(self): + for var in self.all_check_vars.values(): + var.set(False) + self.first_entity_var.set("") + self.update_selection() + self.create_entity_list() + + def on_filter_change(self, *args): + self.current_page = 0 # Reset to first page on filter change + filter_text = self.filter_var.get().lower() + + if not filter_text: + self.filtered_entity_list = self.entity_list.copy() + else: + self.filtered_entity_list = [ + entity for entity in self.entity_list if filter_text in entity.lower() + ] + + self.create_entity_list() + + def clear_filter(self): + self.filter_var.set("") + self.current_page = 0 + self.create_entity_list() + + def select_all_of_entity_type(self): + selected_type = self.entity_type.get() + if not selected_type: + messagebox.showinfo( + "No Type Selected", + "Please select an Entity Type from the dropdown first.", + ) + return + + for var in self.all_check_vars.values(): + var.set(False) + + entities_of_selected_type = [] + for label in self.entity_list: + if label not in self.entity_data: + self.entity_data[label] = fetch_entity_details(label) + + if self.entity_data[label]["type"] == selected_type: + self.all_check_vars[label].set(True) + entities_of_selected_type.append(label) + + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = sorted( + entities_of_selected_type, key=lambda x: x.lower() + ) + self.create_entity_list() + + if not entities_of_selected_type: + messagebox.showinfo( + "No Entities Found", f"No entities of type '{selected_type}' found." + ) + + def select_all_orphans(self): + for var in self.all_check_vars.values(): + var.set(False) + + orphan_entities = [] + for label in self.entity_list: + if label not in self.entity_data: + self.entity_data[label] = fetch_entity_details(label) + + related_nodes = self.entity_data[label].get("related_nodes", []) + edges = self.entity_data[label].get("edges", []) + + if not related_nodes and not edges: + self.all_check_vars[label].set(True) + orphan_entities.append(label) + + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = sorted(orphan_entities, key=lambda x: x.lower()) + self.create_entity_list() + + if not orphan_entities: + messagebox.showinfo( + "No Orphans Found", "No entities without relationships were found." + ) + + def show_selected_only(self): + selected_entities = [ + label for label, var in self.all_check_vars.items() if var.get() + ] + if not selected_entities: + messagebox.showinfo("No Selection", "No entities are currently selected.") + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = self.entity_list.copy() + self.create_entity_list() + return + + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = sorted(selected_entities, key=lambda x: x.lower()) + self.create_entity_list() + + def clear_all_selections(self): + for var in self.all_check_vars.values(): + var.set(False) + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = self.entity_list.copy() + self.first_entity_var.set("") + self.strategy_desc.set("join_unique") + self.strategy_srcid.set("join_unique") + self.create_entity_list() + + def create_right_panel(self): + self.right_panel.grid_rowconfigure(1, weight=1) + self.right_panel.grid_columnconfigure(0, weight=1) + + control_frame = ttk.Frame(self.right_panel) + control_frame.grid(row=0, column=0, sticky="ew", padx=10, pady=10) + control_frame.grid_columnconfigure(0, weight=0) + control_frame.grid_columnconfigure(1, weight=1) + + ttk.Label(control_frame, text="Target Entity:").grid( + row=0, column=0, sticky="w", padx=5, pady=2 + ) + self.target_entry = ttk.Combobox(control_frame, values=[], width=40) + self.target_entry.grid(row=0, column=1, sticky="ew", padx=5, pady=2) + self.target_entry.bind( + "", lambda event: self.copy_selected_text(self.target_entry) + ) + + ttk.Label(control_frame, text="Merge Strategy - Description:").grid( + row=1, column=0, sticky="w", padx=5, pady=2 + ) + self.strategy_desc = ttk.Combobox( + control_frame, values=["concatenate", "keep_first", "join_unique"] + ) + self.strategy_desc.set("join_unique") + self.strategy_desc.grid(row=1, column=1, sticky="ew", padx=5, pady=2) + + ttk.Label(control_frame, text="Merge Strategy - Source ID:").grid( + row=2, column=0, sticky="w", padx=5, pady=2 + ) + self.strategy_srcid = ttk.Combobox( + control_frame, values=["concatenate", "keep_first", "join_unique"] + ) + self.strategy_srcid.set("join_unique") + self.strategy_srcid.grid(row=2, column=1, sticky="ew", padx=5, pady=2) + + self.entity_type_button = ttk.Button( + control_frame, + text="Select Entity Type", + command=self.open_all_entity_types_modal, + ) + self.entity_type_button.grid(row=3, column=0, sticky="w", padx=5, pady=2) + self.entity_type = ttk.Combobox(control_frame, values=[], width=37) + self.entity_type.grid(row=3, column=1, sticky="ew", padx=5, pady=2) + self.entity_type.bind( + "", lambda event: self.copy_selected_text(self.entity_type) + ) + + info_label = ttk.Label( + control_frame, + text="Note: 'Keep First' strategy uses the selected radio button item.", + font=("TkDefaultFont", 8), + foreground="gray", + wraplength=300, + ) + info_label.grid(row=4, column=0, columnspan=2, sticky="w", padx=5, pady=10) + + button_row_frame = ttk.Frame(control_frame) + button_row_frame.grid(row=5, column=0, columnspan=2, sticky="ew", pady=10) + button_row_frame.grid_columnconfigure(0, weight=1) + button_row_frame.grid_columnconfigure(1, weight=1) + button_row_frame.grid_columnconfigure(2, weight=1) + + self.merge_button = ttk.Button( + button_row_frame, text="Merge Entities", command=self.submit_merge + ) + self.merge_button.grid(row=0, column=0, sticky="w", padx=(5, 5)) + + self.create_rel_button = ttk.Button( + button_row_frame, + text="Create Relationship", + command=self.open_create_relationship_modal, + ) + self.create_rel_button.grid(row=0, column=1, sticky="w", padx=(5, 5)) + + self.delete_button = ttk.Button( + button_row_frame, + text="Delete Entities", + command=lambda: asyncio.run(self.run_delete()), + ) + self.delete_button.grid(row=0, column=2, sticky="w", padx=(5, 5)) + + self.merge_button.grid_remove() + self.create_rel_button.grid_remove() + self.delete_button.grid_remove() + + self.description_area = ttk.Frame(self.right_panel) + self.description_area.grid( + row=1, column=0, sticky="nsew", padx=10, pady=(0, 10) + ) + self.description_area.grid_columnconfigure(0, weight=1) + self.description_area.grid_rowconfigure(0, weight=1) + + self.desc_canvas = tk.Canvas(self.description_area) + self.desc_scrollbar = ttk.Scrollbar( + self.description_area, orient="vertical", command=self.desc_canvas.yview + ) + self.desc_scrollable_frame = ttk.Frame(self.desc_canvas) + + self.desc_scrollable_frame.bind( + "", + lambda e: self.desc_canvas.configure( + scrollregion=self.desc_canvas.bbox("all") + ), + ) + self.desc_canvas.create_window( + (0, 0), window=self.desc_scrollable_frame, anchor="nw" + ) + self.desc_canvas.configure(yscrollcommand=self.desc_scrollbar.set) + + self.desc_canvas.grid(row=0, column=0, sticky="nsew") + self.desc_scrollbar.grid(row=0, column=1, sticky="ns") + + def _on_desc_mousewheel(event): + try: + if platform.system() == "Windows": + if hasattr(event, "delta") and event.delta: + self.desc_canvas.yview_scroll( + int(-1 * (event.delta / 120)), "units" + ) + elif platform.system() == "Darwin": + if hasattr(event, "delta") and event.delta: + self.desc_canvas.yview_scroll(int(-1 * event.delta), "units") + else: + if event.num == 4: + self.desc_canvas.yview_scroll(-1, "units") + elif event.num == 5: + self.desc_canvas.yview_scroll(1, "units") + elif hasattr(event, "delta") and event.delta: + self.desc_canvas.yview_scroll( + int(-1 * (event.delta / 120)), "units" + ) + except Exception: + pass + return "break" + + desc_scroll_events = [] + if platform.system() == "Windows": + desc_scroll_events = ["", ""] + elif platform.system() == "Darwin": + desc_scroll_events = ["", "", ""] + else: + desc_scroll_events = ["", "", ""] + + for event in desc_scroll_events: + try: + self.desc_canvas.bind(event, _on_desc_mousewheel) + except Exception: + pass + + def open_all_entity_types_modal(self): + all_types = set() + for label in self.entity_list: + if label not in self.entity_data: + self.entity_data[label] = fetch_entity_details(label) + entity_type = self.entity_data[label].get("type", "") + if entity_type and not entity_type.startswith("Error:"): + all_types.add(entity_type) + all_types = sorted(list(all_types)) + + if not all_types: + messagebox.showinfo( + "No Entity Types", "No entity types found for any entities." + ) + return + + modal = tk.Toplevel(self.root) + modal.title("Select Entity Type") + modal.transient(self.root) + modal.grab_set() + modal.protocol("WM_DELETE_WINDOW", modal.destroy) + + modal.grid_columnconfigure(0, weight=1) + modal.grid_rowconfigure(0, weight=1) + + list_frame = ttk.Frame(modal) + list_frame.grid(row=0, column=0, sticky="nsew", padx=10, pady=10) + list_frame.grid_rowconfigure(0, weight=1) + list_frame.grid_columnconfigure(0, weight=1) + + listbox = tk.Listbox(list_frame, font=("TkDefaultFont", 10), height=10) + listbox.grid(row=0, column=0, sticky="nsew") + scrollbar = ttk.Scrollbar(list_frame, orient="vertical", command=listbox.yview) + listbox.configure(yscrollcommand=scrollbar.set) + scrollbar.grid(row=0, column=1, sticky="ns") + + for entity_type in all_types: + listbox.insert(tk.END, entity_type) + + def on_double_click(event): + selection = listbox.curselection() + if selection: + selected_type = listbox.get(selection[0]) + self.entity_type.set(selected_type) + modal.destroy() + + listbox.bind("", on_double_click) + + button_frame = ttk.Frame(modal) + button_frame.grid(row=1, column=0, sticky="ew", padx=10, pady=10) + cancel_button = ttk.Button(button_frame, text="Cancel", command=modal.destroy) + cancel_button.pack(side="right") + + modal.geometry("300x300") + x = self.root.winfo_x() + (self.root.winfo_width() // 2) - 150 + y = self.root.winfo_y() + (self.root.winfo_height() // 2) - 150 + modal.geometry(f"+{x}+{y}") + + def update_selection(self): + selected = [label for label, var in self.all_check_vars.items() if var.get()] + + if not hasattr(self, "target_entry"): + return + + self.target_entry["values"] = selected + + types = set() + for label in selected: + if label not in self.entity_data: + self.entity_data[label] = fetch_entity_details(label) + + if self.entity_data[label].get("type") and not str( + self.entity_data[label]["type"] + ).startswith("Error:"): + types.add(self.entity_data[label]["type"]) + + current_entity_type = self.entity_type.get() + self.entity_type["values"] = sorted(list(types)) + + if current_entity_type and current_entity_type in self.entity_type["values"]: + self.entity_type.set(current_entity_type) + else: + self.entity_type.set("") + + if self.target_entry.get() not in selected: + self.target_entry.set("") + + if self.first_entity_var.get() and self.first_entity_var.get() not in selected: + self.first_entity_var.set("") + + self.update_button_visibility() + + def update_button_visibility(self): + selected_from_all = { + label for label, var in self.all_check_vars.items() if var.get() + } + + if ( + selected_from_all + and set(self.filtered_entity_list) == selected_from_all + and len(self.filtered_entity_list) == len(selected_from_all) + ): + self.merge_button.grid() + self.delete_button.grid() + else: + self.merge_button.grid_remove() + self.delete_button.grid_remove() + + if len(selected_from_all) == 2: + entity1, entity2 = list(selected_from_all) + if entity1 not in self.entity_data: + self.entity_data[entity1] = fetch_entity_details(entity1) + if entity2 not in self.entity_data: + self.entity_data[entity2] = fetch_entity_details(entity2) + + edges1 = self.entity_data[entity1].get("edges", []) + has_relationship = any( + (edge["source"] == entity1 and edge["target"] == entity2) + or (edge["source"] == entity2 and edge["target"] == entity1) + for edge in edges1 + ) + + if not has_relationship: + self.create_rel_button.grid() + else: + self.create_rel_button.grid_remove() + else: + self.create_rel_button.grid_remove() + + def calculate_tile_layout( + self, available_width, available_height, num_items, min_height=600 + ): + if num_items == 0: + return 0, 0, 0, 0 + + effective_min_width = max(300, (available_width // 2) - 10) + + cols = max(1, available_width // effective_min_width) + + if available_width >= 2 * effective_min_width and num_items > 1: + cols = 2 + elif num_items == 1: + cols = 1 + else: + cols = 1 + + cols = min(cols, num_items) + + rows = (num_items + cols - 1) // cols + + frame_width = available_width // cols + frame_height = max( + min_height, available_height // rows if rows > 0 else available_height + ) + + return cols, rows, frame_width, frame_height + + def toggle_descriptions(self): + any_open = bool(self.description_frames) + if any_open: + for frame in list(self.description_frames.values()): + frame.destroy() + self.description_frames.clear() + self.desc_header_label["text"] = "Show\nDesc" + self.desc_scrollable_frame.update_idletasks() + self.desc_canvas.configure(scrollregion=self.desc_canvas.bbox("all")) + else: + selected = [ + label for label, var in self.all_check_vars.items() if var.get() + ] + if not selected: + messagebox.showinfo( + "No Selection", + "Please select some entities first to show descriptions.", + ) + return + + self.root.update_idletasks() + self.desc_scrollable_frame.update_idletasks() + available_width = self.desc_scrollable_frame.winfo_width() + available_height = self.desc_scrollable_frame.winfo_height() + + if available_width < 100: + available_width = self.right_panel.winfo_width() - 20 + if available_width < 100: + available_width = 800 + if available_height < 100: + available_height = 600 + + cols, rows, frame_width, frame_height = self.calculate_tile_layout( + available_width, available_height, len(selected), min_height=600 + ) + + for c in range(cols): + self.desc_scrollable_frame.grid_columnconfigure(c, weight=1) + for r in range(rows): + self.desc_scrollable_frame.grid_rowconfigure(r, weight=1) + + for idx, label in enumerate(selected): + try: + if label not in self.entity_data: + self.entity_data[label] = fetch_entity_details(label) + + entity_details = self.entity_data[label] + desc = entity_details.get("desc", "No description found.") + entity_type = entity_details.get("type", "No type found.") + srcid = entity_details.get("srcid", "") + fpath = entity_details.get("fpath", "") + related_nodes = entity_details.get("related_nodes", []) + edges = entity_details.get("edges", []) + + row = idx // cols + col = idx % cols + + desc_frame = ttk.LabelFrame( + self.desc_scrollable_frame, text="", padding=5 + ) + desc_frame.config(width=frame_width - 4, height=frame_height - 4) + desc_frame.grid_propagate(False) + + desc_frame.grid(row=row, column=col, sticky="nsew", padx=2, pady=2) + + header_sub_frame = ttk.Frame(desc_frame) + header_sub_frame.pack(fill="x") + header_sub_frame.grid_columnconfigure(0, weight=1) + header_sub_frame.grid_columnconfigure(1, weight=0) + header_sub_frame.grid_columnconfigure(2, weight=0) + + ttk.Label( + header_sub_frame, text=label, font=("TkDefaultFont", 10, "bold") + ).grid(row=0, column=0, sticky="w") + + edit_button = ttk.Button( + header_sub_frame, + text="Edit Description", + command=lambda lbl=label: self.open_edit_description_modal(lbl), + ) + edit_button.grid(row=0, column=1, sticky="e", padx=(5, 5)) + + edit_rel_button = ttk.Button( + header_sub_frame, + text="Edit/Delete Relationships", + command=lambda lbl=label: self.open_edit_relationships_modal( + lbl + ), + ) + edit_rel_button.grid(row=0, column=2, sticky="e", padx=(0, 5)) + + text_frame = ttk.Frame(desc_frame) + text_frame.pack(fill="both", expand=True, pady=(5, 0)) + + text_widget = tk.Text( + text_frame, wrap="word", font=("TkDefaultFont", 9), height=15 + ) + text_scrollbar = ttk.Scrollbar( + text_frame, orient="vertical", command=text_widget.yview + ) + text_widget.configure(yscrollcommand=text_scrollbar.set) + + text_content_parts = [] + if entity_type and entity_type != "No type found.": + text_content_parts.append(f"Type: {entity_type}") + + if related_nodes: + text_content_parts.append( + f"Related Entities: {len(related_nodes)}" + ) + + if desc and desc != "No description found.": + text_content_parts.append(f"Description:\n{desc}") + + if srcid: + text_content_parts.append(f"Source ID:\n{srcid}") + + if fpath: + text_content_parts.append(f"File Path: {fpath}") + + if related_nodes: + for idx, node in enumerate(related_nodes, start=1): + node_id = node.get("id", "N/A") + node_desc = node.get("description", "") + node_type = node.get("type", "") + text_content_parts.append(f"\nRelated Entity {idx}:") + text_content_parts.append( + f"- {node_id} (Type: {node_type})" + ) + if node_desc and node_desc != "No description found.": + text_content_parts.append(f" Description: {node_desc}") + + if edges: + filtered_edges = [ + edge + for edge in edges + if edge.get("source") == label + or edge.get("target") == label + ] + if filtered_edges: + text_content_parts.append("\nRelationships:") + for edge in filtered_edges: + source = edge.get("source", "N/A") + target = edge.get("target", "N/A") + edge_desc = edge.get( + "description", "No description provided." + ) + edge_keywords = edge.get("keywords", "") + edge_weight = edge.get("weight", 1.0) + text_content_parts.append( + f"- From: {source}\n To: {target}\n Relation: {edge_desc}\n Weight: {edge_weight}" + ) + if edge_keywords: + text_content_parts.append( + f" Keywords: {edge_keywords}" + ) + + text_widget.insert("1.0", "\n\n".join(text_content_parts)) + text_widget.config(state="disabled") + + text_widget.pack(side="left", fill="both", expand=True) + text_scrollbar.pack(side="right", fill="y") + + self.description_frames[label] = desc_frame + + except Exception as e: + print(f"Error showing description for {label}: {e}") + + self.desc_scrollable_frame.update_idletasks() + self.desc_canvas.configure(scrollregion=self.desc_canvas.bbox("all")) + + if self.description_frames: + self.desc_header_label["text"] = "Hide\nDesc" + + def open_edit_description_modal(self, entity_label): + modal = tk.Toplevel(self.root) + modal.title(f"Edit Description for {entity_label}") + modal.transient(self.root) + modal.grab_set() + modal.protocol("WM_DELETE_WINDOW", modal.destroy) + + modal.grid_columnconfigure(0, weight=1) + modal.grid_rowconfigure(1, weight=1) + + current_description = self.entity_data.get(entity_label, {}).get("desc", "") + + ttk.Label( + modal, + text=f"Editing Description for: {entity_label}", + font=("TkDefaultFont", 10, "bold"), + ).grid(row=0, column=0, padx=10, pady=5, sticky="w") + + text_frame = ttk.Frame(modal) + text_frame.grid(row=1, column=0, sticky="nsew", padx=10, pady=5) + text_frame.grid_rowconfigure(0, weight=1) + text_frame.grid_columnconfigure(0, weight=1) + + description_text = tk.Text(text_frame, wrap="word", font=("TkDefaultFont", 10)) + description_text.insert("1.0", current_description) + + text_scrollbar = ttk.Scrollbar( + text_frame, orient="vertical", command=description_text.yview + ) + description_text.configure(yscrollcommand=text_scrollbar.set) + + description_text.pack(side="left", fill="both", expand=True) + text_scrollbar.pack(side="right", fill="y") + + button_frame = ttk.Frame(modal) + button_frame.grid(row=2, column=0, sticky="ew", padx=10, pady=10) + button_frame.grid_columnconfigure(0, weight=1) + + save_button = ttk.Button( + button_frame, + text="Save", + command=lambda: self.save_entity_description( + entity_label, description_text, modal + ), + ) + save_button.pack(side="right", padx=(5, 0)) + + cancel_button = ttk.Button(button_frame, text="Cancel", command=modal.destroy) + cancel_button.pack(side="right") + + self.root.update_idletasks() + modal.update_idletasks() + min_width = 400 + min_height = 300 + if modal.winfo_width() < min_width or modal.winfo_height() < min_height: + modal.geometry(f"{min_width}x{min_height}") + modal.update_idletasks() + + x = ( + self.root.winfo_x() + + (self.root.winfo_width() // 2) + - (modal.winfo_width() // 2) + ) + y = ( + self.root.winfo_y() + + (self.root.winfo_height() // 2) + - (modal.winfo_height() // 2) + ) + modal.geometry(f"+{x}+{y}") + + description_text.bind( + "", lambda event: self.copy_selected_text(description_text) + ) + + self.root.wait_window(modal) + + def open_create_relationship_modal(self): + selected = [label for label, var in self.all_check_vars.items() if var.get()] + if len(selected) != 2: + messagebox.showerror( + "Error", "Please select exactly two entities to create a relationship." + ) + return + + entity1, entity2 = selected + modal = tk.Toplevel(self.root) + modal.title("Create New Relationship") + modal.transient(self.root) + modal.grab_set() + modal.protocol("WM_DELETE_WINDOW", modal.destroy) + + modal.grid_columnconfigure(0, weight=1) + modal.grid_columnconfigure(1, weight=0) + modal.grid_rowconfigure(5, weight=0) + + ttk.Label( + modal, + text="Create Relationship Between Entities", + font=("TkDefaultFont", 10, "bold"), + ).grid(row=0, column=0, padx=10, pady=5, sticky="w") + + ttk.Label(modal, text="Source Entity:").grid( + row=1, column=0, sticky="w", padx=10, pady=2 + ) + source_var = tk.StringVar(value=entity1) + source_combobox = ttk.Combobox( + modal, textvariable=source_var, values=selected, state="readonly" + ) + source_combobox.grid(row=2, column=0, sticky="ew", padx=10, pady=2) + + ttk.Label(modal, text="Target Entity:").grid( + row=3, column=0, sticky="w", padx=2 + ) + target_var = tk.StringVar(value=entity2) + target_combobox = ttk.Combobox( + modal, textvariable=target_var, values=selected, state="readonly" + ) + target_combobox.grid(row=4, column=0, sticky="ew", padx=10, pady=2) + + ttk.Label(modal, text="Relationship Description:").grid( + row=5, column=0, sticky="w", padx=10, pady=2 + ) + desc_text = tk.Text(modal, wrap="word", font=("TkDefaultFont", 10), height=5) + desc_scrollbar = ttk.Scrollbar( + modal, orient="vertical", command=desc_text.yview + ) + desc_text.configure(yscrollcommand=desc_scrollbar.set) + desc_text.grid(row=6, column=0, sticky="nsew", padx=10, pady=2) + desc_scrollbar.grid(row=6, column=1, sticky="ns") + + ttk.Label(modal, text="Keywords (comma-separated):").grid( + row=7, column=0, sticky="w", padx=10, pady=2 + ) + keywords_entry = ttk.Entry(modal) + keywords_entry.grid(row=8, column=0, sticky="ew", padx=10, pady=2) + + ttk.Label(modal, text="Weight (1.0-10.0):").grid( + row=9, column=0, sticky="w", padx=10, pady=2 + ) + weight_var = tk.StringVar(value="7.0") + weight_entry = ttk.Entry(modal, textvariable=weight_var) + weight_entry.grid(row=10, column=0, sticky="ew", padx=10, pady=2) + + ttk.Label(modal, text="Source File ID:").grid( + row=11, column=0, sticky="w", padx=10, pady=2 + ) + source_file_id_entry = ttk.Entry(modal) + source_file_id_entry.grid(row=12, column=0, sticky="ew", padx=10, pady=2) + + button_frame = ttk.Frame(modal) + button_frame.grid(row=13, column=0, columnspan=2, sticky="ew", padx=10, pady=10) + button_frame.grid_columnconfigure(0, weight=1) + + save_button = ttk.Button( + button_frame, + text="Save", + command=lambda: self.save_new_relationship( + source_var.get(), + target_var.get(), + desc_text, + keywords_entry.get(), + weight_var.get(), + source_file_id_entry.get(), + modal, + ), + ) + save_button.pack(side="right", padx=(5, 0)) + + cancel_button = ttk.Button(button_frame, text="Cancel", command=modal.destroy) + cancel_button.pack(side="right") + + modal.geometry("500x450") + x = self.root.winfo_x() + (self.root.winfo_width() // 2) - 250 + y = self.root.winfo_y() + (self.root.winfo_height() // 2) - 300 + modal.geometry(f"+{x}+{y}") + + desc_text.bind("", lambda event: self.copy_selected_text(desc_text)) + + def save_new_relationship( + self, source_id, target_id, desc_text, keywords, weight, source_file_id, modal + ): + description = desc_text.get("1.0", tk.END).strip() + if ( + not source_id + or not target_id + or not description + or not keywords + or not weight + or not source_file_id + ): + messagebox.showerror("Error", "All fields are required.") + return + + try: + weight_float = float(weight) + if not (1.0 <= weight_float <= 10.0): + messagebox.showerror("Error", "Weight must be between 1.0 and 10.0.") + return + except ValueError: + messagebox.showerror("Error", "Weight must be a valid number.") + return + + if create_relationship_api( + source_id, target_id, description, keywords, weight_float, source_file_id + ): + trigger_server_refresh() + self.entity_data[source_id] = fetch_entity_details(source_id) + self.entity_data[target_id] = fetch_entity_details(target_id) + if self.description_frames: + for frame in list(self.description_frames.values()): + frame.destroy() + self.description_frames.clear() + self.toggle_descriptions() + modal.destroy() + self.create_entity_list() + + def open_edit_relationships_modal(self, entity_label): + modal = tk.Toplevel(self.root) + modal.title(f"Edit Relationships for {entity_label}") + modal.transient(self.root) + modal.grab_set() + modal.protocol("WM_DELETE_WINDOW", modal.destroy) + + modal.grid_columnconfigure(0, weight=1) + modal.grid_rowconfigure(2, weight=1) + + ttk.Label( + modal, + text=f"Select Relationship for: {entity_label}", + font=("TkDefaultFont", 10, "bold"), + ).grid(row=0, column=0, padx=10, pady=5, sticky="w") + + entity_details = self.entity_data.get(entity_label, {}) + edges = entity_details.get("edges", []) + filtered_edges = [ + edge + for edge in edges + if edge.get("source") == entity_label or edge.get("target") == entity_label + ] + + if not filtered_edges: + ttk.Label( + modal, text="No relationships found.", font=("TkDefaultFont", 10) + ).grid(row=1, column=0, padx=10, pady=5, sticky="w") + button_frame = ttk.Frame(modal) + button_frame.grid(row=2, column=0, sticky="ew", padx=10, pady=10) + ttk.Button(button_frame, text="Close", command=modal.destroy).pack( + side="right" + ) + return + + relationship_options = [ + f"From: {edge['source']} To: {edge['target']}" for edge in filtered_edges + ] + selected_relationship = tk.StringVar() + relationship_combobox = ttk.Combobox( + modal, + textvariable=selected_relationship, + values=relationship_options, + state="readonly", + ) + relationship_combobox.grid(row=1, column=0, sticky="ew", padx=10, pady=5) + if relationship_options: + relationship_combobox.set(relationship_options[0]) + + edit_frame = ttk.Frame(modal) + edit_frame.grid(row=2, column=0, sticky="nsew", padx=10, pady=5) + edit_frame.grid_columnconfigure(0, weight=1) + edit_frame.grid_rowconfigure(1, weight=1) + + ttk.Label(edit_frame, text="Relationship Description:").grid( + row=0, column=0, sticky="w", padx=5, pady=2 + ) + desc_text = tk.Text( + edit_frame, wrap="word", font=("TkDefaultFont", 10), height=5 + ) + desc_scrollbar = ttk.Scrollbar( + edit_frame, orient="vertical", command=desc_text.yview + ) + desc_text.configure(yscrollcommand=desc_scrollbar.set) + desc_text.grid(row=1, column=0, sticky="nsew", padx=5, pady=2) + desc_scrollbar.grid(row=1, column=1, sticky="ns") + + ttk.Label(edit_frame, text="Keywords (comma-separated):").grid( + row=2, column=0, sticky="w", padx=5, pady=2 + ) + keywords_entry = ttk.Entry(edit_frame) + keywords_entry.grid(row=3, column=0, sticky="ew", padx=5, pady=2) + + ttk.Label(edit_frame, text="Weight (1.0-10.0):").grid( + row=4, column=0, sticky="w", padx=5, pady=2 + ) + weight_var = tk.StringVar() + weight_entry = ttk.Entry(edit_frame, textvariable=weight_var) + weight_entry.grid(row=5, column=0, sticky="ew", padx=5, pady=2) + + def update_fields(*args): + selected_idx = relationship_combobox.current() + if selected_idx >= 0: + edge = filtered_edges[selected_idx] + desc_text.delete("1.0", tk.END) + desc_text.insert("1.0", edge.get("description", "")) + keywords_entry.delete(0, tk.END) + keywords_entry.insert(0, edge.get("keywords", "")) + weight_var.set(str(edge.get("weight", 1.0))) + + selected_relationship.trace("w", update_fields) + update_fields() + + button_frame = ttk.Frame(modal) + button_frame.grid(row=3, column=0, sticky="ew", padx=10, pady=10) + button_frame.grid_columnconfigure(0, weight=1) + + save_button = ttk.Button( + button_frame, + text="Save", + command=lambda: self.save_relationship( + entity_label, + filtered_edges, + relationship_combobox.current(), + desc_text, + keywords_entry, + weight_var.get(), + modal, + ), + ) + save_button.pack(side="right", padx=(5, 0)) + + cancel_button = ttk.Button(button_frame, text="Cancel", command=modal.destroy) + cancel_button.pack(side="right") + + delete_button = ttk.Button( + button_frame, + text="Delete Relationship", + command=lambda: asyncio.run( + self.delete_relationship( + entity_label, filtered_edges, relationship_combobox.current(), modal + ) + ), + ) + delete_button.pack(side="left", padx=(0, 5)) + + self.root.update_idletasks() + modal.update_idletasks() + min_width = 500 + min_height = 400 + modal.geometry(f"{min_width}x{min_height}") + + x = ( + self.root.winfo_x() + + (self.root.winfo_width() // 2) + - (modal.winfo_width() // 2) + ) + y = ( + self.root.winfo_y() + + (self.root.winfo_height() // 2) + - (modal.winfo_height() // 2) + ) + modal.geometry(f"+{x}+{y}") + + self.root.wait_window(modal) + + def save_relationship( + self, + entity_label, + edges, + selected_idx, + desc_text, + keywords_entry, + weight, + modal, + ): + if selected_idx < 0: + messagebox.showerror("Error", "No relationship selected.") + return + + edge = edges[selected_idx] + source_id = edge["source"] + target_id = edge["target"] + new_description = desc_text.get("1.0", tk.END).strip() + new_keywords = keywords_entry.get().strip() + + try: + weight_float = float(weight) + if not (1.0 <= weight_float <= 10.0): + messagebox.showerror("Error", "Weight must be between 1.0 and 10.0.") + return + except ValueError: + messagebox.showerror("Error", "Weight must be a valid number.") + return + + if update_relationship_api( + source_id, target_id, new_description, new_keywords, weight_float + ): + for e in self.entity_data[entity_label]["edges"]: + if e["source"] == source_id and e["target"] == target_id: + e["description"] = new_description + e["keywords"] = new_keywords + e["weight"] = weight_float + break + + if self.description_frames: + for frame in list(self.description_frames.values()): + frame.destroy() + self.description_frames.clear() + self.toggle_descriptions() + + modal.destroy() + + async def delete_relationship(self, entity_label, edges, selected_idx, modal): + if selected_idx < 0 or selected_idx >= len(edges): + messagebox.showerror("Error", "No relationship selected to delete.") + return + + edge = edges[selected_idx] + source_id = edge["source"] + target_id = edge["target"] + + confirm = messagebox.askyesno( + "Confirm Deletion", + f"Are you sure you want to delete the relationship from:\n\n{source_id}\n\nto\n\n{target_id}?", + ) + if not confirm: + return + + try: + rag = await initialize_rag() + await rag.adelete_by_relation(source_id, target_id) + + messagebox.showinfo( + "Success", f"Deleted relationship from '{source_id}' to '{target_id}'." + ) + + print("Triggering LightRAG server refresh...") + if not trigger_server_refresh(): + print("Server data refresh failed or server not running.") + else: + print("Server refresh attempted.") + + if entity_label in self.entity_data: + self.entity_data[entity_label] = fetch_entity_details(entity_label) + + modal.destroy() + self.open_edit_relationships_modal(entity_label) + + except Exception as e: + messagebox.showerror("Error", f"Failed to delete relationship:\n{e}") + finally: + if "rag" in locals(): + await rag.finalize_storages() + + def save_entity_description( + self, entity_label, description_text_widget, modal_window + ): + new_description = description_text_widget.get("1.0", tk.END).strip() + + if update_entity_description_api(entity_label, new_description): + if entity_label in self.entity_data: + self.entity_data[entity_label]["desc"] = new_description + + if self.description_frames: + for frame in list(self.description_frames.values()): + frame.destroy() + self.description_frames.clear() + self.toggle_descriptions() + + modal_window.destroy() + + def submit_merge(self): + if not self.entity_type.get() or not self.target_entry.get(): + messagebox.showerror( + "Missing info", "Please select a target entity and entity type." + ) + return + + selected = [label for label, var in self.all_check_vars.items() if var.get()] + + if not selected: + messagebox.showerror("No entities", "Select at least one source entity.") + return + + selected_from_filtered = { + label + for label in self.filtered_entity_list + if self.all_check_vars.get(label, tk.BooleanVar()).get() + } + if not ( + set(selected) == selected_from_filtered + and len(selected) == len(selected_from_filtered) + ): + messagebox.showerror( + "Operation Not Allowed", + "Merge operation can only be performed when 'Selected Only' mode is active and the displayed list exactly matches the selected entities.", + ) + return + + for frame in list(self.description_frames.values()): + frame.destroy() + self.description_frames.clear() + self.desc_header_label["text"] = "Show\nDesc" + + strategy = { + "description": self.strategy_desc.get(), + "source_id": self.strategy_srcid.get(), + } + + if ( + strategy["description"] == "keep_first" + or strategy["source_id"] == "keep_first" + ): + first_entity = self.first_entity_var.get() + if not first_entity: + messagebox.showerror( + "Missing Selection", + "Please select which entity should be 'first' using the radio buttons when using 'keep_first' strategy.", + ) + return + if first_entity not in selected: + messagebox.showerror( + "Invalid Selection", + "The selected 'first' entity must be in the list of selected entities.", + ) + return + selected = [first_entity] + [e for e in selected if e != first_entity] + + asyncio.run( + self.run_merge( + selected, + self.target_entry.get(), + strategy, + etype=self.entity_type.get(), + ) + ) + + async def run_merge(self, sources, target, strategy, etype): + rag = await initialize_rag() + try: + await rag.amerge_entities( + source_entities=sources, + target_entity=target, + merge_strategy=strategy, + target_entity_data={"entity_type": etype}, + ) + messagebox.showinfo("Success", f"Entities merged into '{target}'") + + print("Refreshing LightRAG server data from disk...") + if not trigger_server_refresh(): + print( + "Server data refresh failed or server not running. Manual restart might still be needed if changes don't appear." + ) + else: + print("Server refresh attempted.") + + self.entity_data.clear() + + newly_fetched_entities = fetch_entities() + + new_all_check_vars = {} + for entity in newly_fetched_entities: + new_all_check_vars[entity] = self.all_check_vars.get( + entity, tk.BooleanVar() + ) + self.all_check_vars = new_all_check_vars + self.entity_list = newly_fetched_entities + + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = self.entity_list.copy() + for var in self.all_check_vars.values(): + var.set(False) + self.first_entity_var.set("") + + self.create_entity_list() + + except Exception as e: + messagebox.showerror("Error", str(e)) + finally: + await rag.finalize_storages() + + async def run_delete(self): + selected_for_deletion = [ + label for label, var in self.all_check_vars.items() if var.get() + ] + + if not selected_for_deletion: + messagebox.showinfo("No entities", "No entities are selected for deletion.") + return + + selected_from_filtered = { + label + for label in self.filtered_entity_list + if self.all_check_vars.get(label, tk.BooleanVar()).get() + } + if not ( + set(selected_for_deletion) == selected_from_filtered + and len(selected_for_deletion) == len(selected_from_filtered) + ): + messagebox.showerror( + "Operation Not Allowed", + "Delete operation can only be performed when 'Selected Only' mode is active and the displayed list exactly matches the selected entities.", + ) + return + + if len(selected_for_deletion) > 1: + confirm_message = ( + f"Are you sure you want to delete the following {len(selected_for_deletion)} entities?\n\n" + + "\n".join(selected_for_deletion[:10]) + + ("..." if len(selected_for_deletion) > 10 else "") + ) + else: + confirm_message = ( + f"Are you sure you want to delete '{selected_for_deletion[0]}'?" + ) + + if not messagebox.askyesno("Confirm Deletion", confirm_message): + return + + for frame in list(self.description_frames.values()): + frame.destroy() + self.description_frames.clear() + self.desc_header_label["text"] = "Show\nDesc" + + rag = await initialize_rag() + try: + success_count = 0 + fail_count = 0 + failed_entities = [] + + for entity_to_delete in selected_for_deletion: + print(f"Attempting to delete entity: {entity_to_delete}") + try: + await rag.adelete_by_entity(entity_to_delete) + print(f"Successfully deleted: {entity_to_delete}") + success_count += 1 + except Exception as e: + print(f"Failed to delete {entity_to_delete}: {e}") + fail_count += 1 + failed_entities.append(entity_to_delete) + + if success_count > 0: + messagebox.showinfo( + "Deletion Complete", + f"Successfully deleted {success_count} entities. Failed to delete {fail_count} entities. Please check logs for details.", + ) + if failed_entities: + messagebox.showerror( + "Deletion Errors", + "Failed to delete:\n" + "\n".join(failed_entities), + ) + + print("Refreshing LightRAG server data from disk after deletion...") + if not trigger_server_refresh(): + print( + "Server data refresh failed or server not running. Manual restart might still be needed if changes don't appear." + ) + else: + print("Server refresh attempted.") + + self.entity_data.clear() + newly_fetched_entities = fetch_entities() + + new_all_check_vars = {} + for entity in newly_fetched_entities: + new_all_check_vars[entity] = self.all_check_vars.get( + entity, tk.BooleanVar() + ) + self.all_check_vars = new_all_check_vars + self.entity_list = newly_fetched_entities + + self.filter_var.set("") + self.current_page = 0 + self.filtered_entity_list = self.entity_list.copy() + for var in self.all_check_vars.values(): + var.set(False) + self.first_entity_var.set("") + + self.create_entity_list() + + elif fail_count > 0: + messagebox.showerror( + "Deletion Failed", + "No entities were successfully deleted. Please check logs for details.", + ) + + except Exception as e: + messagebox.showerror( + "Error", f"An unexpected error occurred during deletion: {e}" + ) + finally: + await rag.finalize_storages() + + def on_closing(self): + self.save_window_config() + self.root.destroy() + + +if __name__ == "__main__": + root = tk.Tk() + + # === ADD THIS BLOCK HERE === + import tkinter.font as tkfont + + # Force a thicker, larger, more readable default font + default_font = tkfont.nametofont("TkDefaultFont") + default_font.configure(family="DejaVu Sans", size=12, weight="bold") + + # Also fix the other standard Tk fonts for consistency + tkfont.nametofont("TkTextFont").configure(family="DejaVu Sans", size=12) + tkfont.nametofont("TkFixedFont").configure(family="DejaVu Sans Mono", size=12) + + # Optional: if you prefer a different thick font that's likely available + # default_font.configure(family="Liberation Sans", size=12, weight="bold") + # or + # default_font.configure(family="Arial", size=12, weight="bold") + # =========================== + + app = MergeGUI(root) + root.mainloop() diff --git a/jrs/_notes/.env b/jrs/_notes/.env new file mode 100644 index 0000000000..3342000cc6 --- /dev/null +++ b/jrs/_notes/.env @@ -0,0 +1,508 @@ +### You must have a .env file in LightRAG directory. +### This is a sample .env file that works for me. + +########################### +### Server Configuration +########################### +HOST=0.0.0.0 +PORT=9621 +WEBUI_TITLE='Knowledge Graph' +WEBUI_DESCRIPTION="Simple and Fast Graph Based RAG System" +# WORKERS=2 +### gunicorn worker timeout(as default LLM request timeout if LLM_TIMEOUT is not set) +# TIMEOUT=150 +# CORS_ORIGINS=http://localhost:3000,http://localhost:8080 + +### Optional SSL Configuration +# SSL=true +# SSL_CERTFILE=/path/to/cert.pem +# SSL_KEYFILE=/path/to/key.pem + +### Directory Configuration (defaults to current working directory) +### Default value is ./inputs and ./rag_storage +INPUT_DIR= "Directory where your source docs are located" +WORKING_DIR= "/home/js/LightRAG/jrs/work/my_work_dir" + +### Tiktoken cache directory (Store cached files in this folder for offline deployment) +# TIKTOKEN_CACHE_DIR=/app/data/tiktoken + +### Ollama Emulating Model and Tag +# OLLAMA_EMULATING_MODEL_NAME=lightrag +OLLAMA_EMULATING_MODEL_TAG=latest + +### Max nodes for graph retrieval (Ensure WebUI local settings are also updated, which is limited to this value) +# MAX_GRAPH_NODES=1000 + +### Logging level +# LOG_LEVEL=INFO +# VERBOSE=False +# LOG_MAX_BYTES=10485760 +# LOG_BACKUP_COUNT=5 +### Logfile location (defaults to current working directory) +# LOG_DIR=/path/to/log/directory + +##################################### +### Login and API-Key Configuration +##################################### +# AUTH_ACCOUNTS='admin:admin123,user1:pass456' +# TOKEN_SECRET=Your-Key-For-LightRAG-API-Server +# TOKEN_EXPIRE_HOURS=48 +# GUEST_TOKEN_EXPIRE_HOURS=24 +# JWT_ALGORITHM=HS256 + +### API-Key to access LightRAG Server API +### Use this key in HTTP requests with the 'X-API-Key' header +### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query +# LIGHTRAG_API_KEY=your-secure-api-key-here +# WHITELIST_PATHS=/health,/api/* + +###################################################################################### +### Query Configuration +### +### How to control the context length sent to LLM: +### MAX_ENTITY_TOKENS + MAX_RELATION_TOKENS < MAX_TOTAL_TOKENS +### Chunk_Tokens = MAX_TOTAL_TOKENS - Actual_Entity_Tokens - Actual_Relation_Tokens +###################################################################################### +# LLM response cache for query (Not valid for streaming response) +ENABLE_LLM_CACHE=true +# COSINE_THRESHOLD=0.2 +### Number of entities or relations retrieved from KG +# TOP_K=40 +### Maximum number or chunks for naive vector search +# CHUNK_TOP_K=20 +### control the actual entities send to LLM +# MAX_ENTITY_TOKENS=6000 +### control the actual relations send to LLM +# MAX_RELATION_TOKENS=8000 +### control the maximum tokens send to LLM (include entities, relations and chunks) +# MAX_TOTAL_TOKENS=30000 + +### chunk selection strategies +### VECTOR: Pick KG chunks by vector similarity, delivered chunks to the LLM aligning more closely with naive retrieval +### WEIGHT: Pick KG chunks by entity and chunk weight, delivered more solely KG related chunks to the LLM +### If reranking is enabled, the impact of chunk selection strategies will be diminished. +# KG_CHUNK_PICK_METHOD=VECTOR + +######################################################### +### Reranking configuration +### RERANK_BINDING type: null, cohere, jina, aliyun +### For rerank model deployed by vLLM use cohere binding +######################################################### +RERANK_BINDING=null +### Enable rerank by default in query params when RERANK_BINDING is not null +# RERANK_BY_DEFAULT=True +### rerank score chunk filter(set to 0.0 to keep all chunks, 0.6 or above if LLM is not strong enough) +# MIN_RERANK_SCORE=0.0 + +### For local deployment with vLLM +# RERANK_MODEL=BAAI/bge-reranker-v2-m3 +# RERANK_BINDING_HOST=http://localhost:8000/v1/rerank +# RERANK_BINDING_API_KEY=your_rerank_api_key_here + +### Default value for Cohere AI +# RERANK_MODEL=rerank-v3.5 +# RERANK_BINDING_HOST=https://api.cohere.com/v2/rerank +# RERANK_BINDING_API_KEY=your_rerank_api_key_here +### Cohere rerank chunking configuration (useful for models with token limits like ColBERT) +# RERANK_ENABLE_CHUNKING=true +# RERANK_MAX_TOKENS_PER_DOC=480 + +### Default value for Jina AI +# RERANK_MODEL=jina-reranker-v2-base-multilingual +# RERANK_BINDING_HOST=https://api.jina.ai/v1/rerank +# RERANK_BINDING_API_KEY=your_rerank_api_key_here + +### Default value for Aliyun +# RERANK_MODEL=gte-rerank-v2 +# RERANK_BINDING_HOST=https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank +# RERANK_BINDING_API_KEY=your_rerank_api_key_here + +######################################## +### Document processing configuration +######################################## +ENABLE_LLM_CACHE_FOR_EXTRACT=true + +### Document processing output language: English, Chinese, French, German ... +SUMMARY_LANGUAGE=English + +### PDF decryption password for protected PDF files +# PDF_DECRYPT_PASSWORD=your_pdf_password_here + +### Entity types that the LLM will attempt to recognize +# ENTITY_TYPES='["Person", "Creature", "Organization", "Location", "Event", "Concept", "Method", "Content", "Data", "Artifact", "NaturalObject"]' + +### Chunk size for document splitting, 500~1500 is recommended +# CHUNK_SIZE=1200 +# CHUNK_OVERLAP_SIZE=100 + +### Number of summary segments or tokens to trigger LLM summary on entity/relation merge (at least 3 is recommended) +# FORCE_LLM_SUMMARY_ON_MERGE=8 +### Max description token size to trigger LLM summary +# SUMMARY_MAX_TOKENS = 1200 +### Recommended LLM summary output length in tokens +# SUMMARY_LENGTH_RECOMMENDED_=600 +### Maximum context size sent to LLM for description summary +# SUMMARY_CONTEXT_SIZE=12000 + +### control the maximum chunk_ids stored in vector and graph db +# MAX_SOURCE_IDS_PER_ENTITY=300 +# MAX_SOURCE_IDS_PER_RELATION=300 +### control chunk_ids limitation method: FIFO, KEEP +### FIFO: First in first out +### KEEP: Keep oldest (less merge action and faster) +# SOURCE_IDS_LIMIT_METHOD=FIFO + +# Maximum number of file paths stored in entity/relation file_path field (For displayed only, does not affect query performance) +# MAX_FILE_PATHS=100 + +### maximum number of related chunks per source entity or relation +### The chunk picker uses this value to determine the total number of chunks selected from KG(knowledge graph) +### Higher values increase re-ranking time +# RELATED_CHUNK_NUMBER=5 + +############################### +### Concurrency Configuration +############################### +### Max concurrency requests of LLM (for both query and document processing) +MAX_ASYNC=2 +### Number of parallel processing documents(between 2~10, MAX_ASYNC/3 is recommended) +MAX_PARALLEL_INSERT=2 +### Max concurrency requests for Embedding +### I dropped this from a default of 16 to 2 because of rate limit errors with OpenAI +EMBEDDING_FUNC_MAX_ASYNC=2 +### Num of chunks send to Embedding in single request +# EMBEDDING_BATCH_NUM=10 + +########################################################################### +### LLM Configuration +### LLM_BINDING type: openai, ollama, lollms, azure_openai, aws_bedrock, gemini +### LLM_BINDING_HOST: host only for Ollama, endpoint for other LLM service +### If LightRAG deployed in Docker: +### uses host.docker.internal instead of localhost in LLM_BINDING_HOST +########################################################################### +### LLM request timeout setting for all llm (0 means no timeout for Ollma) +# LLM_TIMEOUT=180 + +LLM_BINDING=openai + +# Larger and more expensive LLM which likely gives better results. +#LLM_MODEL=gpt-4o +# Smaller and less expensive LLM which likely does not respond as well as the larger model. +# In anycase, this model gives very good performance for the money spent. +# At some point I will start to use local LLMs to avoid expenses entirely. +# Electricity and equipment costs will go up, of course, when using a local LLM. +LLM_MODEL=gpt-4o-mini + + +LLM_BINDING_HOST=https://api.openai.com/v1 +# Your_API_Key below if not specified already in your .bashrc file +# LLM_BINDING_API_KEY=your_api_key + +### Azure OpenAI example +### Use deployment name as model name or set AZURE_OPENAI_DEPLOYMENT instead +# AZURE_OPENAI_API_VERSION=2024-08-01-preview +# LLM_BINDING=azure_openai +# LLM_BINDING_HOST=https://xxxx.openai.azure.com/ +# LLM_BINDING_API_KEY=your_api_key + +# LLM_MODEL=my-gpt-mini-deployment + +### Openrouter example +# LLM_MODEL=google/gemini-2.5-flash +# LLM_BINDING_HOST=https://openrouter.ai/api/v1 +# LLM_BINDING_API_KEY=your_api_key +# LLM_BINDING=openai + +### Gemini example +# LLM_BINDING=gemini +# LLM_MODEL=gemini-flash-latest +# LLM_BINDING_API_KEY=your_gemini_api_key +# LLM_BINDING_HOST=https://generativelanguage.googleapis.com + +### use the following command to see all support options for OpenAI, azure_openai or OpenRouter +### lightrag-server --llm-binding gemini --help +### Gemini Specific Parameters +# GEMINI_LLM_MAX_OUTPUT_TOKENS=9000 +# GEMINI_LLM_TEMPERATURE=0.7 +### Enable Thinking +# GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": -1, "include_thoughts": true}' +### Disable Thinking +# GEMINI_LLM_THINKING_CONFIG='{"thinking_budget": 0, "include_thoughts": false}' + +### use the following command to see all support options for OpenAI, azure_openai or OpenRouter +### lightrag-server --llm-binding openai --help +### OpenAI Specific Parameters +# OPENAI_LLM_REASONING_EFFORT=minimal +### OpenRouter Specific Parameters +# OPENAI_LLM_EXTRA_BODY='{"reasoning": {"enabled": false}}' +### Qwen3 Specific Parameters deploy by vLLM +# OPENAI_LLM_EXTRA_BODY='{"chat_template_kwargs": {"enable_thinking": false}}' + +### OpenAI Compatible API Specific Parameters +### Increased temperature values may mitigate infinite inference loops in certain LLM, such as Qwen3-30B. +# OPENAI_LLM_TEMPERATURE=0.9 +### Set the max_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s) +### Typically, max_tokens does not include prompt content +### For vLLM/SGLang deployed models, or most of OpenAI compatible API provider +# OPENAI_LLM_MAX_TOKENS=9000 +### For OpenAI o1-mini or newer modles utilizes max_completion_tokens instead of max_tokens +OPENAI_LLM_MAX_COMPLETION_TOKENS=9000 + +### use the following command to see all support options for Ollama LLM +### lightrag-server --llm-binding ollama --help +### Ollama Server Specific Parameters +### OLLAMA_LLM_NUM_CTX must be provided, and should at least larger than MAX_TOTAL_TOKENS + 2000 +OLLAMA_LLM_NUM_CTX=32768 +### Set the max_output_tokens to mitigate endless output of some LLM (less than LLM_TIMEOUT * llm_output_tokens/second, i.e. 9000 = 180s * 50 tokens/s) +# OLLAMA_LLM_NUM_PREDICT=9000 +### Stop sequences for Ollama LLM +# OLLAMA_LLM_STOP='["", "<|EOT|>"]' + +### Bedrock Specific Parameters +# BEDROCK_LLM_TEMPERATURE=1.0 + +####################################################################################### +### Embedding Configuration (Should not be changed after the first file processed) +### EMBEDDING_BINDING: ollama, openai, azure_openai, jina, lollms, aws_bedrock +### EMBEDDING_BINDING_HOST: host only for Ollama, endpoint for other Embedding service +### If LightRAG deployed in Docker: +### uses host.docker.internal instead of localhost in EMBEDDING_BINDING_HOST +####################################################################################### +# EMBEDDING_TIMEOUT=30 + +### Control whether to send embedding_dim parameter to embedding API +### IMPORTANT: Jina ALWAYS sends dimension parameter (API requirement) - this setting is ignored for Jina +### For OpenAI: Set to 'true' to enable dynamic dimension adjustment +### For OpenAI: Set to 'false' (default) to disable sending dimension parameter +### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama) + +# Ollama embedding +# EMBEDDING_BINDING=ollama +# EMBEDDING_MODEL=bge-m3:latest +# EMBEDDING_DIM=1024 +# EMBEDDING_BINDING_API_KEY=your_api_key +### If LightRAG deployed in Docker uses host.docker.internal instead of localhost +# EMBEDDING_BINDING_HOST=http://localhost:11434 + +### OpenAI compatible embedding +EMBEDDING_BINDING=openai +EMBEDDING_MODEL=text-embedding-3-large +EMBEDDING_DIM=3072 +EMBEDDING_SEND_DIM=false +EMBEDDING_TOKEN_LIMIT=8192 +EMBEDDING_BINDING_HOST=https://api.openai.com/v1 +# Your_API_Key below if not specified already in your .bashrc file +# EMBEDDING_BINDING_API_KEY=your_api_key + +### Optional for Azure embedding +### Use deployment name as model name or set AZURE_EMBEDDING_DEPLOYMENT instead +# AZURE_EMBEDDING_API_VERSION=2024-08-01-preview +# EMBEDDING_BINDING=azure_openai +# EMBEDDING_BINDING_HOST=https://xxxx.openai.azure.com/ +# EMBEDDING_API_KEY=your_api_key +# EMBEDDING_MODEL==my-text-embedding-3-large-deployment +# EMBEDDING_DIM=3072 + +### Gemini embedding +# EMBEDDING_BINDING=gemini +# EMBEDDING_MODEL=gemini-embedding-001 +# EMBEDDING_DIM=1536 +# EMBEDDING_TOKEN_LIMIT=2048 +# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com +# EMBEDDING_BINDING_API_KEY=your_api_key +### Gemini embedding requires sending dimension to server +# EMBEDDING_SEND_DIM=true + +### Jina AI Embedding +# EMBEDDING_BINDING=jina +# EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings +# EMBEDDING_MODEL=jina-embeddings-v4 +# EMBEDDING_DIM=2048 +# EMBEDDING_BINDING_API_KEY=your_api_key + +### Optional for Ollama embedding +OLLAMA_EMBEDDING_NUM_CTX=8192 +### use the following command to see all support options for Ollama embedding +### lightrag-server --embedding-binding ollama --help + +#################################################################### +### WORKSPACE sets workspace name for all storage types +### for the purpose of isolating data from LightRAG instances. +### Valid workspace name constraints: a-z, A-Z, 0-9, and _ +#################################################################### +# WORKSPACE=space1 + +############################ +### Data storage selection +############################ +### Default storage (Recommended for small scale deployment) +# LIGHTRAG_KV_STORAGE=JsonKVStorage +# LIGHTRAG_DOC_STATUS_STORAGE=JsonDocStatusStorage +# LIGHTRAG_GRAPH_STORAGE=NetworkXStorage +# LIGHTRAG_VECTOR_STORAGE=NanoVectorDBStorage + +### Redis Storage (Recommended for production deployment) +# LIGHTRAG_KV_STORAGE=RedisKVStorage +# LIGHTRAG_DOC_STATUS_STORAGE=RedisDocStatusStorage + +### Vector Storage (Recommended for production deployment) +# LIGHTRAG_VECTOR_STORAGE=MilvusVectorDBStorage +# LIGHTRAG_VECTOR_STORAGE=QdrantVectorDBStorage +# LIGHTRAG_VECTOR_STORAGE=FaissVectorDBStorage + +### Graph Storage (Recommended for production deployment) +# LIGHTRAG_GRAPH_STORAGE=Neo4JStorage +# LIGHTRAG_GRAPH_STORAGE=MemgraphStorage + +### PostgreSQL +# LIGHTRAG_KV_STORAGE=PGKVStorage +# LIGHTRAG_DOC_STATUS_STORAGE=PGDocStatusStorage +# LIGHTRAG_GRAPH_STORAGE=PGGraphStorage +# LIGHTRAG_VECTOR_STORAGE=PGVectorStorage + +### MongoDB (Vector storage only available on Atlas Cloud) +# LIGHTRAG_KV_STORAGE=MongoKVStorage +# LIGHTRAG_DOC_STATUS_STORAGE=MongoDocStatusStorage +# LIGHTRAG_GRAPH_STORAGE=MongoGraphStorage +# LIGHTRAG_VECTOR_STORAGE=MongoVectorDBStorage + +### PostgreSQL Configuration +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_USER=your_username +POSTGRES_PASSWORD='your_password' +POSTGRES_DATABASE=your_database +POSTGRES_MAX_CONNECTIONS=12 +### DB specific workspace should not be set, keep for compatible only +### POSTGRES_WORKSPACE=forced_workspace_name + +### PostgreSQL Vector Storage Configuration +### Vector storage type: HNSW, IVFFlat, VCHORDRQ +POSTGRES_VECTOR_INDEX_TYPE=HNSW +POSTGRES_HNSW_M=16 +POSTGRES_HNSW_EF=200 +POSTGRES_IVFFLAT_LISTS=100 +POSTGRES_VCHORDRQ_BUILD_OPTIONS= +POSTGRES_VCHORDRQ_PROBES= +POSTGRES_VCHORDRQ_EPSILON=1.9 + +### PostgreSQL Connection Retry Configuration (Network Robustness) +### Number of retry attempts (1-10, default: 3) +### Initial retry backoff in seconds (0.1-5.0, default: 0.5) +### Maximum retry backoff in seconds (backoff-60.0, default: 5.0) +### Connection pool close timeout in seconds (1.0-30.0, default: 5.0) +# POSTGRES_CONNECTION_RETRIES=3 +# POSTGRES_CONNECTION_RETRY_BACKOFF=0.5 +# POSTGRES_CONNECTION_RETRY_BACKOFF_MAX=5.0 +# POSTGRES_POOL_CLOSE_TIMEOUT=5.0 + +### PostgreSQL SSL Configuration (Optional) +# POSTGRES_SSL_MODE=require +# POSTGRES_SSL_CERT=/path/to/client-cert.pem +# POSTGRES_SSL_KEY=/path/to/client-key.pem +# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem +# POSTGRES_SSL_CRL=/path/to/crl.pem + +### PostgreSQL Server Settings (for Supabase Supavisor) +# Use this to pass extra options to the PostgreSQL connection string. +# For Supabase, you might need to set it like this: +# POSTGRES_SERVER_SETTINGS="options=reference%3D[project-ref]" + +# Default is 100 set to 0 to disable +# POSTGRES_STATEMENT_CACHE_SIZE=100 + +### Neo4j Configuration +NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io +NEO4J_USERNAME=neo4j +NEO4J_PASSWORD='your_password' +NEO4J_DATABASE=neo4j +NEO4J_MAX_CONNECTION_POOL_SIZE=100 +NEO4J_CONNECTION_TIMEOUT=30 +NEO4J_CONNECTION_ACQUISITION_TIMEOUT=30 +NEO4J_MAX_TRANSACTION_RETRY_TIME=30 +NEO4J_MAX_CONNECTION_LIFETIME=300 +NEO4J_LIVENESS_CHECK_TIMEOUT=30 +NEO4J_KEEP_ALIVE=true +### DB specific workspace should not be set, keep for compatible only +### NEO4J_WORKSPACE=forced_workspace_name + +### MongoDB Configuration +MONGO_URI=mongodb://root:root@localhost:27017/ +#MONGO_URI=mongodb+srv://xxxx +MONGO_DATABASE=LightRAG +# MONGODB_WORKSPACE=forced_workspace_name + +### Milvus Configuration +MILVUS_URI=http://localhost:19530 +MILVUS_DB_NAME=lightrag +# MILVUS_USER=root +# MILVUS_PASSWORD=your_password +# MILVUS_TOKEN=your_token +### DB specific workspace should not be set, keep for compatible only +### MILVUS_WORKSPACE=forced_workspace_name + +### Qdrant +QDRANT_URL=http://localhost:6333 +# QDRANT_API_KEY=your-api-key +### DB specific workspace should not be set, keep for compatible only +### QDRANT_WORKSPACE=forced_workspace_name + +### Redis +REDIS_URI=redis://localhost:6379 +REDIS_SOCKET_TIMEOUT=30 +REDIS_CONNECT_TIMEOUT=10 +REDIS_MAX_CONNECTIONS=100 +REDIS_RETRY_ATTEMPTS=3 +### DB specific workspace should not be set, keep for compatible only +### REDIS_WORKSPACE=forced_workspace_name + +### Memgraph Configuration +MEMGRAPH_URI=bolt://localhost:7687 +MEMGRAPH_USERNAME= +MEMGRAPH_PASSWORD= +MEMGRAPH_DATABASE=memgraph +### DB specific workspace should not be set, keep for compatible only +### MEMGRAPH_WORKSPACE=forced_workspace_name + +########################################################### +### Langfuse Observability Configuration +### Only works with LLM provided by OpenAI compatible API +### Install with: pip install lightrag-hku[observability] +### Sign up at: https://cloud.langfuse.com or self-host +########################################################### +# LANGFUSE_SECRET_KEY="" +# LANGFUSE_PUBLIC_KEY="" +# LANGFUSE_HOST="https://cloud.langfuse.com" # 或您的自托管实例地址 +# LANGFUSE_ENABLE_TRACE=true + +############################ +### Evaluation Configuration +############################ +### RAGAS evaluation models (used for RAG quality assessment) +### ⚠️ IMPORTANT: Both LLM and Embedding endpoints MUST be OpenAI-compatible +### Default uses OpenAI models for evaluation + +### LLM Configuration for Evaluation +# EVAL_LLM_MODEL=gpt-4o-mini +### API key for LLM evaluation (fallback to OPENAI_API_KEY if not set) +# EVAL_LLM_BINDING_API_KEY=your_api_key +### Custom OpenAI-compatible endpoint for LLM evaluation (optional) +# EVAL_LLM_BINDING_HOST=https://api.openai.com/v1 + +### Embedding Configuration for Evaluation +# EVAL_EMBEDDING_MODEL=text-embedding-3-large +### API key for embeddings (fallback: EVAL_LLM_BINDING_API_KEY -> OPENAI_API_KEY) +# EVAL_EMBEDDING_BINDING_API_KEY=your_embedding_api_key +### Custom OpenAI-compatible endpoint for embeddings (fallback: EVAL_LLM_BINDING_HOST) +# EVAL_EMBEDDING_BINDING_HOST=https://api.openai.com/v1 + +### Performance Tuning +### Number of concurrent test case evaluations +### Lower values reduce API rate limit issues but increase evaluation time +# EVAL_MAX_CONCURRENT=2 +### TOP_K query parameter of LightRAG (default: 10) +### Number of entities or relations retrieved from KG +# EVAL_QUERY_TOP_K=10 +### LLM request retry and timeout settings for evaluation +# EVAL_LLM_MAX_RETRIES=5 +# EVAL_LLM_TIMEOUT=180 diff --git a/jrs/_notes/_0_install_use_Instructions.txt b/jrs/_notes/_0_install_use_Instructions.txt new file mode 100644 index 0000000000..984835d6d8 --- /dev/null +++ b/jrs/_notes/_0_install_use_Instructions.txt @@ -0,0 +1,128 @@ + +############## The commands in this section are used to install the libraries #################### + + +# I happen to be using Ubuntu on WSL. +# To install this fork of LightRAG run the following script: +# jrs/_notes/setup.sh +# Be sure to read and understand the script before using it. + + +# After running the script, do the following to continue setting up the LightRAG library: + +# Add your API keys to the .env file." +# Or for better security, add API keys to your .bashrc file instead." +# No matter where you decide to put your API keys, you will still need a .env file for other required settings" +# If you prefer, you can use the .env file I have in the "LightRAG/jrs/_notes" directory which is working for me." +# Just copy that file to the LightRAG folder." + + +To use LightRAG, run these commands: + 1. cd LightRAG + + 2. source .venv/bin/activate + This activates the virtual environment. + + 3. export PS1='(.venv) \w\$ ' + This creates a better looking prompt that takes up less space on the screen. + + 4. lightrag-server + This starts the server. + + 5. Visit: http://localhost:9621/webui/ + This WebUI is where you interact with the server. + + 6. You can interact directly with LightRAG using the + Python scripts in the jrs directory without the need + for the lightrag server or the WebUI + + +############## The commands is this section are for running the scripts in this repository #################### + + +# To index documents so that the information is available for query, or to run a query, use the following commands at the bash terminal. + +# Navigate to the LightRAG directory. +cd LightRAG + +# Activate the LightRAG virtual environment with the following command. +source .venv/bin/activate + +# Creates a better looking prompt that takes up less space on the screen. +export PS1='(.venv) \w\$ ' + + +# To index the document of interest. +# First index the metadata json file if any with the following script: +jrs/_1_custom_index_01.py +# Open the script in your text editor. +nano jrs/_1_custom_index_01.py + +# Specify the WORKING_DIR where you wish to have the index created. +# This is where all the files for the knowledge graph, vector database, and supporting files will reside. +# Below is an example of what you might specify inside the file. +WORKING_DIR = "/home/js/LightRAG/jrs/work/_0_seheult_work_dir" + +# Specify the json files to be indexed. +# Below is an example of what you might specify inside the file. +# You can specify as many documents as you want in the list. +files_2b_indexed = [ + "/home/js/LightRAG/jrs/work/_0_seheult_metadata/_bNySyEobfY_metadata.json", + "/home/js/LightRAG/jrs/work/_0_seheult_metadata/0m1Qekrfs7w_metadata.json", +] + +# Then run the script with the following command: +python3 jrs/_1_custom_index_01.py + +# Now you have some metadata in your knowledge graph and vector database. +# This is information about the YouTube videos from which transcripts were created. +# This is information like YouTube channel, Video title, posting date, and so on. + + +# Now we want to index the transcripts so that we can see their knowledge graphs and so that we can talk to an a.i. about them. + +# The following script is used for indexing text, images, equations, and tables in your documents: +# jrs/_1_ra_index.py +# To see how to use the script and what arguments must be supplied, run the following command: +python3 jrs/_1_ra_index.py --help + +# Then index the text, images, equations, and tables in your document with the following command: +python3 jrs/_1_ra_index.py + +# Text only documents can be indexed using the LightRAG server and the WebUI. + +# Index the document and it's metadata first in its own individual directory before indexing it with other documents. +# Then examine the index with the following GUI to determine misspellings and duplicate entities with minor spelling differences: +python3 /home/js/LightRAG/jrs/_1_merge_GUI_58.py +# Fix all of the spelling errors in the orignial documents and consolidate the spellings of names and other entities. +# Look for in the _0_sample_data folder for detailed notes on cleaning up dirty data. + +# To query your indexes use the following commands. + +# Navigate to the LightRAG directory. +cd LightRAG + +# Activate the LightRAG virtual environment with the following command. +source .venv/bin/activate + +# Create a better looking prompt +export PS1='(.venv) \w\$ ' + +# If your indexes contain text only, it is best to use the LightRAG Server and the WebUI for running your queries. +# If your indexes contain text only and you need to run the query for several modes (naive, local, global, hybrid, or mix) then use the following script: +python3 jrs/_2_ra_query_text.py --help +python3 jrs/_2_ra_query_text.py + +# If your indexes contain text, equations, images, or tables then the following script is the best tool for the job. +python3 jrs/_2_ra_query_image.py --help +python3 jrs/_2_ra_query_image.py + + +# WSL Networking: +# In WSL, you may need to access the server from Windows. +# To do this, find the WSL IP address with the following bash command: +ip addr show eth0 | grep inet +# Look for an IP like 172.x.x.x. +You can access the WebUI from a Windows browser at http://:9621 (replace with the actual IP). + +Have Fun! diff --git a/jrs/_notes/setup-dev.sh b/jrs/_notes/setup-dev.sh new file mode 100644 index 0000000000..8adbdb1c8b --- /dev/null +++ b/jrs/_notes/setup-dev.sh @@ -0,0 +1,218 @@ +#!/bin/bash + +# This script installs a development version of LightRAG and RAGAnywhere into a directory named LightRAG-Dev. +# This installation will allow modifications to the WebUI + + +# The following pre-migration checklist is for me only. +# It ensures I am giving you my latest code and the latest code from LightRAG too. +# ============================================================================== +# PRE-MIGRATION & SYNC CHECKLIST (Run this in your Dev environment) +# ============================================================================== +# 1. --- Navigate to the directory we wish to sync --- + # cd ~/LightRAG-Dev + +# 2. --- SAVE YOUR WEBUI WORK --- +# Ensure your frontend changes are built and staged + # cd lightrag_webui + # bun run build + # cd .. + +# 3. --- COMMIT LOCAL CHANGES --- + # pre-commit run --all-files + # uv pip install -e . # ← KEEP OUR PATCH ACTIVE + # git add . + # git commit -m "Checkpoint: Save WebUI modifications and built assets" + +# 4. --- CREATE CHECKPOINT FOR DISASTER ROLLBACK -- + # git branch backup/pre-migration-2026-02-02 !!! USE THE CORRECT DATE !!! + # git push origin backup/pre-migration-2026-02-02 + +# 5. --- SYNC WITH UPSTREAM (HKUDS/LightRAG) --- + # git fetch upstream + # git merge upstream/main -X ours + # uv pip install -e . # ← KEEP OUR PATCH ACTIVE +# +# --> Note: The -X ours flag above tells Git: "If we both changed the exact same line, keep my version." +# --> Note: If conflicts occur, prioritize keeping your new WebUI files +# but accept upstream fixes for the core RAG logic. + +# 6. --- UPDATE DEPENDENCY LOCK FILES --- +# Ensure both Python and Frontend dependencies are healthy after the merge + # uv lock + # cd lightrag_webui + # bun install + # cd .. + +# 7. --- OPTIONAL AND UNTESTED: PURGE GHOST ARTIFACTS --- +# Clear Python/uv cache + # uv cache clean + +# Navigate to WebUI and clear Vite/Bun cache + # cd lightrag_webui + +# Remove the local 'dist' and Vite's internal cache + # rm -rf dist node_modules/.vite +# (Optional) Re-run build to ensure everything is fresh + # bun run build + # cd .. + +# 8. --- FINAL COMMIT --- + # git add . + # git commit -m "Final sync: Integrated upstream changes & updated lock files" +# --> (Only required if Step 4 had conflicts or Step 5 updated locks) + +# 9. --- PUSH TO YOUR FORK --- + # git push origin main + +# 10. --- VERIFY --- + # git status +# --> Should see: "Your branch is up to date with 'origin/main'" +# ============================================================================== + + + + +# ============================================================================== +# INCASE OF DISASTER WITH THE ABOVE SYNC OPERATION, RECOVER WITH: +# git checkout main +# git reset --hard backup/pre-migration-2026-02-02 +# ============================================================================== + + + +#*******************Start of Notes********************* +# Move this file to your home directory before running it. +# This will cause the LightRAG-Dev directory to be created in your home directory. + +# !!!!!!!!!!!!!!!! Important! !!!!!!!!!!!!!!!!! +# If you already have a directory named LightRAG-Dev then change the +# REPO_DIR variable in this script to something other than LightRAG-Dev so that +# your orgininal work will not be overwritten. + +# This script here that you are reading will not run if it is not marked by the OS as an executable file. +# Make this file executable with the following bash command in your terminal window: +# chmod +x setup-dev.sh + +# Then run the script with the following bash command in your terminal window: +# ./setup-dev.sh + +# LightRAG and RAGAnywhere require environment variables in order to run. (API keys, etc.) +# Ensure API keys are defined in a .env file in the LightRAG directory. +# For better security, define API keys in the .bashrc file in you home directory rather than in your .env file. +# The following line should be in your .env file, or for better security in your .bashrc file. +#export OPENAI_API_KEY="My_API_Key" +#*******************End of Notes********************* + + + + + +#*******************Start of Script****************** +set -e + +echo "🚀 Starting LightRAG WebUI Development Setup..." + +# 1. Install uv (Python package manager) +if ! command -v uv &> /dev/null +then + echo "📦 Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.cargo/env +else + echo "✅ uv is already installed." +fi + +# 2. Install bun (Required by LightRAG for Web UI) +if ! command -v bun &> /dev/null +then + echo "📦 Installing bun..." + curl -fsSL https://bun.sh/install | bash + export PATH="$HOME/.bun/bin:$PATH" +else + echo "✅ bun is already installed." +fi + +# 3. Clone the repository +REPO_URL="https://github.com/johnshearing/LightRAG.git" +REPO_DIR="LightRAG-Dev" + +if [ ! -d "$REPO_DIR" ]; then + echo "📂 Cloning your LightRAG fork for development..." + git clone "$REPO_URL" "$REPO_DIR" + cd "$REPO_DIR" + git remote add upstream https://github.com/HKUDS/LightRAG.git +else + echo "🏠 Directory $REPO_DIR already exists. Updating..." + cd "$REPO_DIR" + git pull origin main +fi + +# 4. Create Virtual Environment and Sync for DEVELOPMENT +echo "⚙️ Setting up Python environment in EDITABLE mode..." +# Using --editable ensures changes to the source code are live +uv sync --all-extras --editable + +echo "📦 Installing additional RAGAnything components..." +uv pip install raganything + +# 5. Build the Web UI +echo "🌐 Setting up Web UI for development..." +if [ -d "lightrag_webui" ]; then + cd lightrag_webui + echo "Installing frontend dependencies with bun..." + bun install --frozen-lockfile + + echo "Performing initial frontend build..." + # This places the initial assets into lightrag/api/webui/ + bun run build + cd .. + echo "✅ Web UI initialized." +else + echo "❌ Error: lightrag_webui directory not found!" + exit 1 +fi + +echo "-----------------------------------------------" +echo "🎉 Setup Complete!" + +# 6. Check for .env file +if [ ! -f ".env" ]; then + echo "⚠️ WARNING: No .env file detected." + echo " Please copy your working .env from your backup or jrs/_notes folder." + echo " Then add your API keys to the .env file." + echo " Or for better security, add API keys to your .bashrc file instead." + echo " No matter where you decide to put your API keys, you will still need a .env file for other required settings" +fi + + +# 7. Verification and Instructions +cat << EOF + +To DEVELOP and MODIFY the WebUI, follow these steps: + +1. ACTIVATE ENVIRONMENT: + cd $REPO_DIR + source .venv/bin/activate + +2. CREATE A PROMPT THAT TAKES UP LESS SPACE. + export PS1='(.venv) \w\$ ' + +2. START THE BACKEND: + lightrag-server + +3. LIVE UI DEVELOPMENT (Hot Reloading): + In a NEW terminal: + cd $REPO_DIR/lightrag_webui + bun run dev + # This will give you a link (usually http://localhost:5173) + # that updates instantly as you change code. + +4. FINALIZING UI CHANGES: + When your changes are ready, run: + cd $REPO_DIR/lightrag_webui + bun run build + # This updates the files served by the main lightrag-server. + +----------------------------------------------- +EOF diff --git a/jrs/_notes/setup.sh b/jrs/_notes/setup.sh new file mode 100644 index 0000000000..341abb0734 --- /dev/null +++ b/jrs/_notes/setup.sh @@ -0,0 +1,154 @@ +#!/bin/bash + +# This script installs LightRAG and RAGAnywhere into a directory named LightRAG + + +# The following pre-migration checklist is for me only. +# It ensures I am giving you my latest code and the latest code from LightRAG too. +# ============================================================================== +# PRE-MIGRATION CHECKLIST (Run these on the computer with original source!) +# ============================================================================== +# 1. cd ~/LightRAG +# 2. git add . +# 3. git commit -m "Checkpoint: Save my work before syncing" +# 4. git fetch upstream +# 5. git merge upstream/main +# --> If conflicts occur, fix them in .gitignore / pyproject.toml +# 6. uv lock +# --> This ensures the lock file is healthy after the merge +# 7. git add . +# 8. git commit -m "Final sync: Integrated upstream changes" +# --> (Note: Only required if Step 5 had conflicts) +# 9. git push origin main +# 10. git status +# --> I should see: "Your branch is up to date with 'origin/main'" and "nothing to commit, working tree clean" +# ============================================================================== + + + +#*******************Start of Notes********************* +# Move this file to your home directory before running it. +# This will cause the LightRAG directory to be created in your home directory. + +# !!!!!!!!!!!!!!!! Important! !!!!!!!!!!!!!!!!! +# If you already have a directory named LightRAG then change the +# REPO_DIR variable in this script to something other than LightRAG so that +# your orgininal work will not be overwritten. + +# This script here that you are reading will not run if it is not marked by the OS as an executable file. +# Make this file executable with the following bash command in your terminal window: +# chmod +x setup.sh + +# Then run the script with the following bash command in your terminal window: +# ./setup.sh + +# LightRAG and RAGAnywhere require environment variables in order to run. (API keys, etc.) +# Ensure API keys are defined in a .env file in the LightRAG directory. +# For better security, define API keys in the .bashrc file in you home directory rather than in your .env file. +# The following line should be in your .env file, or for better security in your .bashrc file. +#export OPENAI_API_KEY="My_API_Key" +#*******************End of Notes********************* + + + + + +#*******************Start of Script****************** +set -e + +echo "🚀 Starting RAGAnywhere Environment Setup..." + +# 1. Install uv (Python package manager) +if ! command -v uv &> /dev/null +then + echo "📦 Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.cargo/env +else + echo "✅ uv is already installed." +fi + +# 2. Install bun (Needed for Web UI) +if ! command -v bun &> /dev/null +then + echo "📦 Installing bun..." + curl -fsSL https://bun.sh/install | bash + # For a script, we source the specific bin path rather than the whole bashrc + export PATH="$HOME/.bun/bin:$PATH" +else + echo "✅ bun is already installed." +fi + +# 3. Clone the repository +REPO_URL="https://github.com/johnshearing/LightRAG.git" +REPO_DIR="LightRAG" + +if [ ! -d "$REPO_DIR" ]; then + echo "📂 Cloning your LightRAG fork..." + git clone "$REPO_URL" + cd "$REPO_DIR" +else + echo "🏠 Directory $REPO_DIR already exists. Updating..." + cd "$REPO_DIR" + git pull origin main +fi + +# 4. Create Virtual Environment and Sync Dependencies +echo "⚙️ Syncing Python dependencies with uv..." +uv sync --all-extras + +echo "📦 Installing additional RAGAnything components..." +uv pip install raganything + +# 5. Build the Web UI +echo "🌐 Building Web UI..." +if [ -d "lightrag_webui" ]; then + cd lightrag_webui + bun install --frozen-lockfile + bun run build + cd .. + echo "✅ Web UI built successfully." +else + echo "❌ Error: lightrag_webui directory not found!" + exit 1 +fi + +echo "-----------------------------------------------" +echo "🎉 Setup Complete!" + +# 6. Check for .env file +if [ ! -f ".env" ]; then + echo "⚠️ WARNING: No .env file detected." + echo " Action: cp env.example .env" + echo " Then add your API keys to the .env file." + echo " Or for better security, add API keys to your .bashrc file instead." + echo " No matter where you decide to put your API keys, you will still need a .env file for other required settings" + echo " If you prefer, you can use the .env file I have in the "LightRAG/jrs/_notes" directory which is working for me." + echo " Just copy that file to the LightRAG folder." +fi + + +# 7. Final Verification and Instructions +cat << EOF + + +To use LightRAG, run these commands: + 1. cd LightRAG + + 2. source .venv/bin/activate + This activates the virtual environment. + + 3. export PS1='(.venv) \w\$ ' + This creates a better looking prompt that takes up less space on the screen. + + 4. lightrag-server + This starts the server. + + 5. Visit: http://localhost:9621/webui/ + This WebUI is where you interact with the server. + + 6. You can interact directly with LightRAG using the + Python scripts in the jrs directory without the need + for the lightrag server or the WebUI +----------------------------------------------- +EOF diff --git a/jrs/archive/_1_custom_cat_generate_01.py b/jrs/archive/_1_custom_cat_generate_01.py new file mode 100644 index 0000000000..fbcf87e380 --- /dev/null +++ b/jrs/archive/_1_custom_cat_generate_01.py @@ -0,0 +1,134 @@ +# Takes lines like the following from prompt.py and creates the json required to ingest categories into a LightRAG index using _1_custom_index_01.py +# "category_hub is the hub entity with an entity_type of category_hub and which shares a relationship with every entity that has the entity_type of category.", +# "place is an entity with an entity_type of category which describes any geographic location.", +# "event is an entity with an entity_type of category which describes a particular situation at a specific time.", +# "anatomy is an entity with an entity_type of category which describes any part of a living organism.", + + +import json + + +def generate_category_json(input_filepath, output_filepath): + """ + Generates a JSON file containing chunks, entities, and relationships + based on a list of category descriptions. + + Args: + input_filepath (str): The path to the input text file. + output_filepath (str): The path to the output JSON file. + """ + chunks_content = [] + entities = [] + relationships = [] + + # Hardcode category_hub's correct entry as it's a special case + category_hub_entity = { + "entity_name": "category_hub", + "entity_type": "category_hub", + "description": "category_hub is the hub entity with an entity_type of category_hub and which shares a relationship with every entity that has the entity_type of category.", + "source_id": "category_data.json", + } + + # Add category_hub's entity first + entities.append(category_hub_entity) + chunks_content.append(category_hub_entity["description"]) + + with open(input_filepath, "r") as f: + lines = f.readlines() + + for line in lines: + line = line.strip() + if not line: + continue + + entity_name = "" + entity_type = "" + description = "" + + # Skip the original category_hub line, as we've hardcoded it + if "category_hub (the entity_name is category_hub)" in line: + continue + + # Remove leading/trailing ".,", "," or "." and quotes if they exist on the raw line + line = line.strip().strip(",.").strip('"') + + # General parsing for other entities + parts = line.split(" is an entity with an entity_type of ", 1) + if len(parts) < 2: + # This line doesn't conform to the expected "is an entity with an entity_type of" format, skip it + continue + + entity_name = parts[0].strip().strip('"') # Strip quotes from entity_name + type_and_desc_part = parts[1] + + # Determine the correct description split keyword + description_split_keyword = "" + if " which describes " in type_and_desc_part: + description_split_keyword = " which describes " + elif " which describe " in type_and_desc_part: + description_split_keyword = " which describe " + + # Only proceed if a valid keyword was found + if not description_split_keyword: + # If neither keyword is found, this line doesn't fit the pattern for description extraction + continue + + type_desc_split = type_and_desc_part.split(description_split_keyword, 1) + + if len(type_desc_split) < 2: + # This should ideally not happen if description_split_keyword was found, + # but as a safeguard. + continue + + entity_type = type_desc_split[0].strip().strip('"') + description_suffix = ( + type_desc_split[1].strip().strip('",.') + ) # Strip quotes and punctuation from suffix + + # Reconstruct the description correctly without extra punctuation from source + # Use the actual split keyword in the reconstructed description for accuracy + description = f"{entity_name} is an entity with an entity_type of {entity_type}{description_split_keyword}{description_suffix}." + + # Add to entities list + entity_entry = { + "entity_name": entity_name, + "entity_type": entity_type, + "description": description, + "source_id": "category_data.json", + } + entities.append(entity_entry) + + # Add to chunks content (only for included entities) + chunks_content.append(description) + + # Add to relationships if it's a category entity and not category_hub + if entity_type == "category" and entity_name != "category_hub": + relationship_entry = { + "src_id": "category_hub", + "tgt_id": entity_name, + "description": f"{entity_name} is an element of the set category_hub", + "keywords": f"{entity_name}, element of, category_hub", + "weight": 7.0, + "source_id": "category_data.json", + } + relationships.append(relationship_entry) + + # Construct the final JSON structure + output_data = { + "chunks": [ + {"content": "\n".join(chunks_content), "source_id": "category_data.json"} + ], + "entities": entities, + "relationships": relationships, + } + + # Write the JSON to the output file + with open(output_filepath, "w") as f: + json.dump(output_data, f, indent=4) + + +if __name__ == "__main__": + input_file = "test_categories.txt" + output_file = "test_categories.json" + generate_category_json(input_file, output_file) + print(f"JSON data successfully generated in '{output_file}'") diff --git a/jrs/archive/_1_custom_entity_delete_01.py b/jrs/archive/_1_custom_entity_delete_01.py new file mode 100644 index 0000000000..932b72b6a6 --- /dev/null +++ b/jrs/archive/_1_custom_entity_delete_01.py @@ -0,0 +1,133 @@ +import os +import asyncio +import logging +import logging.config +from lightrag import LightRAG +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import logger, set_verbose_debug, EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# Configuration +WORKING_DIR = "_0_jack_work_dir_01" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) + + +def configure_logging(): + """Configure logging with console and rotating file handlers.""" + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_index.log")) + print(f"\nLightRAG index log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(levelname)s: %(message)s"}, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "true").lower() == "true") + + +if not os.path.exists(WORKING_DIR): + os.makedirs(WORKING_DIR) + + +async def initialize_rag(): + """Initialize LightRAG with custom embedding function.""" + print("Initializing LightRAG for indexing...") + + # Initialize embedding model + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + # Define async embedding function + async def async_embedding_func(texts): + return embed_model.get_text_embedding_batch(texts) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + + # Initialize LightRAG + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +async def main(): + """Main function to index documents.""" + rag = None + try: + if not os.getenv("OPENAI_API_KEY") and not API_KEY: + raise ValueError( + "OPENAI_API_KEY or EMBEDDING_BINDING_API_KEY environment variable not set" + ) + rag = await initialize_rag() + + # Delete By Entity + await rag.adelete_by_entity("Circadian Practice") + + print("Entity Deleted") + except Exception as e: + print(f"An error occurred: {e}") + import traceback + + traceback.print_exc() + finally: + if rag: + print("Finalizing storages...") + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) + print("\nIndexing Done!") diff --git a/jrs/archive/_1_custom_entity_merge_01.py b/jrs/archive/_1_custom_entity_merge_01.py new file mode 100644 index 0000000000..094d6fe57c --- /dev/null +++ b/jrs/archive/_1_custom_entity_merge_01.py @@ -0,0 +1,140 @@ +import os +import asyncio +import logging +import logging.config +from lightrag import LightRAG +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import logger, set_verbose_debug, EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# Configuration +WORKING_DIR = "_0_jack_work_dir_01" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) + + +def configure_logging(): + """Configure logging with console and rotating file handlers.""" + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_index.log")) + print(f"\nLightRAG index log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(levelname)s: %(message)s"}, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "true").lower() == "true") + + +if not os.path.exists(WORKING_DIR): + os.makedirs(WORKING_DIR) + + +async def initialize_rag(): + """Initialize LightRAG with custom embedding function.""" + print("Initializing LightRAG for indexing...") + + # Initialize embedding model + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + # Define async embedding function + async def async_embedding_func(texts): + return embed_model.get_text_embedding_batch(texts) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + + # Initialize LightRAG + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +async def main(): + """Main function to index documents.""" + rag = None + try: + if not os.getenv("OPENAI_API_KEY") and not API_KEY: + raise ValueError( + "OPENAI_API_KEY or EMBEDDING_BINDING_API_KEY environment variable not set" + ) + rag = await initialize_rag() + + # Merge company entities with both strategy and custom data + await rag.amerge_entities( + source_entities=["Dr. Max Gulhane", "Max"], + target_entity="Dr. Max Gulhane", + merge_strategy={"description": "join_unique", "source_id": "join_unique"}, + target_entity_data={ + "entity_type": "person", + }, + ) + + print("Merged") + except Exception as e: + print(f"An error occurred: {e}") + import traceback + + traceback.print_exc() + finally: + if rag: + print("Finalizing storages...") + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) + print("\nIndexing Done!") diff --git a/jrs/archive/_1_custom_export_01.py b/jrs/archive/_1_custom_export_01.py new file mode 100644 index 0000000000..2645712390 --- /dev/null +++ b/jrs/archive/_1_custom_export_01.py @@ -0,0 +1,134 @@ +import os +import asyncio +import logging +import logging.config +from lightrag import LightRAG +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import logger, set_verbose_debug, EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# Configuration +WORKING_DIR = "_0_jack_work_dir_02" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) + + +def configure_logging(): + """Configure logging with console and rotating file handlers.""" + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_index.log")) + print(f"\nLightRAG index log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(levelname)s: %(message)s"}, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "true").lower() == "true") + + +if not os.path.exists(WORKING_DIR): + os.makedirs(WORKING_DIR) + + +async def initialize_rag(): + """Initialize LightRAG with custom embedding function.""" + print("Initializing LightRAG for indexing...") + + # Initialize embedding model + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + # Define async embedding function + async def async_embedding_func(texts): + return embed_model.get_text_embedding_batch(texts) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + + # Initialize LightRAG + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +async def main(): + """Main function to index documents.""" + rag = None + try: + if not os.getenv("OPENAI_API_KEY") and not API_KEY: + raise ValueError( + "OPENAI_API_KEY or EMBEDDING_BINDING_API_KEY environment variable not set" + ) + rag = await initialize_rag() + + # Delete By Relation + # Export data in markdown format + await rag.aexport_data("graph_data.md", file_format="md") + + print("Exported") + except Exception as e: + print(f"An error occurred: {e}") + import traceback + + traceback.print_exc() + finally: + if rag: + print("Finalizing storages...") + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) + print("\nIndexing Done!") diff --git a/jrs/archive/_1_custom_list_json_files.py b/jrs/archive/_1_custom_list_json_files.py new file mode 100644 index 0000000000..eac095963b --- /dev/null +++ b/jrs/archive/_1_custom_list_json_files.py @@ -0,0 +1,26 @@ +# Creates a text file containing a list of all json files in a specified directory. + + +import os + +# Set the directory to search and the output file name +directory = "_0_seheult_data" +output_file = "_0_list_of_json_files.txt" + +# Get list of .json files with full paths +json_files = [ + os.path.join(os.path.abspath(directory), f) + for f in os.listdir(directory) + if f.endswith(".json") +] + +# Write to the output file with no trailing comma on the last line +with open(output_file, "w") as f: + for i, filepath in enumerate(json_files): + quoted = f'"{filepath}"' + if i < len(json_files) - 1: + f.write(quoted + ",\n") + else: + f.write(quoted + "\n") # last line, no comma + +print(f"Finished writing {len(json_files)} .json file(s) to {output_file}") diff --git a/jrs/archive/_1_custom_relation_delete_01.py b/jrs/archive/_1_custom_relation_delete_01.py new file mode 100644 index 0000000000..bf78d0af4b --- /dev/null +++ b/jrs/archive/_1_custom_relation_delete_01.py @@ -0,0 +1,135 @@ +import os +import asyncio +import logging +import logging.config +from lightrag import LightRAG +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import logger, set_verbose_debug, EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# Configuration +WORKING_DIR = "_0_jack_work_dir_01" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) + + +def configure_logging(): + """Configure logging with console and rotating file handlers.""" + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_index.log")) + print(f"\nLightRAG index log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(levelname)s: %(message)s"}, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "true").lower() == "true") + + +if not os.path.exists(WORKING_DIR): + os.makedirs(WORKING_DIR) + + +async def initialize_rag(): + """Initialize LightRAG with custom embedding function.""" + print("Initializing LightRAG for indexing...") + + # Initialize embedding model + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + # Define async embedding function + async def async_embedding_func(texts): + return embed_model.get_text_embedding_batch(texts) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + + # Initialize LightRAG + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +async def main(): + """Main function to index documents.""" + rag = None + try: + if not os.getenv("OPENAI_API_KEY") and not API_KEY: + raise ValueError( + "OPENAI_API_KEY or EMBEDDING_BINDING_API_KEY environment variable not set" + ) + rag = await initialize_rag() + + # Delete By Relation + await rag.adelete_by_relation( + source_entities=["Dr. Max Gulhane", "Max"], target_entity="Dr. Max Gulhane" + ) + + print("Deleted") + except Exception as e: + print(f"An error occurred: {e}") + import traceback + + traceback.print_exc() + finally: + if rag: + print("Finalizing storages...") + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) + print("\nIndexing Done!") diff --git a/jrs/archive/_1_doc_index_01.py b/jrs/archive/_1_doc_index_01.py new file mode 100644 index 0000000000..7d682cfb0a --- /dev/null +++ b/jrs/archive/_1_doc_index_01.py @@ -0,0 +1,42 @@ +import sys +import requests +import json + +# Configuration +API_URL = "http://127.0.0.1:9621" # Default LightRAG API port + +# Change to your document path +INPUT_DIR = "/home/js/LightRAG/jrs/work/mcfadden/mcfadden_data" + + +def trigger_directory_scan(directory_path): + """ + Calls the LightRAG API to scan a directory for new documents. + """ + endpoint = f"{API_URL}/documents/scan" + payload = {"input_dir": directory_path} + + print(f"Sending scan request for: {directory_path}...") + + try: + response = requests.post(endpoint, json=payload) + response.raise_for_status() # Raises error for 4xx/5xx responses + + data = response.json() + print("Successfully triggered scan.") + print(f"Server Response: {json.dumps(data, indent=2)}") + + # Check if the server returned status information + if data.get("status") == "success": + print("\nIndexing has started in the background.") + print("You can monitor progress in your LightRAG terminal/logs.") + + except requests.exceptions.RequestException as e: + print(f"Error connecting to LightRAG server: {e}") + + +if __name__ == "__main__": + # If you provide a path in the command line, use it. + # Otherwise, use the default hard-coded path. + target_dir = sys.argv[1] if len(sys.argv) > 1 else INPUT_DIR + trigger_directory_scan(target_dir) diff --git a/jrs/archive/_1_schematic_to_lightrag.py b/jrs/archive/_1_schematic_to_lightrag.py new file mode 100644 index 0000000000..7757ceeb0c --- /dev/null +++ b/jrs/archive/_1_schematic_to_lightrag.py @@ -0,0 +1,97 @@ +import os +import base64 +import json +import fitz +from openai import OpenAI +from PIL import Image + +API_KEY = os.getenv("OPENAI_API_KEY") +client = OpenAI(api_key=API_KEY) + + +def encode_image(image_path): + with open(image_path, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + + +def process_schematic_to_json(pdf_path, output_json="circuit_logic.json"): + doc = fitz.open(pdf_path) + all_extracted_data = [] + + for page_num in range(len(doc)): + print(f"Processing Page {page_num + 1}...") + page = doc.load_page(page_num) + pix = page.get_pixmap(matrix=fitz.Matrix(5, 5)) # Slightly higher zoom (5x) + + full_img_path = f"temp_full_p{page_num}.png" + pix.save(full_img_path) + + # Open with Pillow for tiling + with Image.open(full_img_path) as img: + w, h = img.size + # Define 4 overlapping quadrants (5% overlap to catch wires on the seams) + mid_w, mid_h = w // 2, h // 2 + overlap = int(w * 0.05) + + tiles = [ + ("Top-Left", (0, 0, mid_w + overlap, mid_h + overlap)), + ("Top-Right", (mid_w - overlap, 0, w, mid_h + overlap)), + ("Bottom-Left", (0, mid_h - overlap, mid_w + overlap, h)), + ("Bottom-Right", (mid_w - overlap, mid_h - overlap, w, h)), + ] + + page_connections = [] + + for tile_name, box in tiles: + print(f" Scanning {tile_name}...") + tile_path = f"temp_tile_{tile_name}.png" + img.crop(box).save(tile_path) + + response = client.chat.completions.create( + model="gpt-4o", + max_tokens=4096, + messages=[ + { + "role": "system", + "content": "You are a document digitizer. Record every text label and connecting line found in this section.", + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "EXHAUSTIVELY list every line connection in this image section. Use JSON: source_component, source_terminal, connection_type, wire_label, target_component, target_terminal.", + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{encode_image(tile_path)}" + }, + }, + ], + }, + ], + response_format={"type": "json_object"}, + ) + + content = response.choices[0].message.content + if content: + data = json.loads(content) + page_connections.extend(data.get("connections", [])) + + os.remove(tile_path) + + all_extracted_data.append( + {"page": page_num + 1, "data": {"connections": page_connections}} + ) + print( + f"Total connections found for Page {page_num + 1}: {len(page_connections)}" + ) + os.remove(full_img_path) + + with open(output_json, "w") as f: + json.dump(all_extracted_data, f, indent=2) + + +if __name__ == "__main__": + process_schematic_to_json("jrs/work/mod_linx/mod_linx_data/PS10115MLC2-2.pdf") diff --git a/jrs/archive/_2_index_query_01.py b/jrs/archive/_2_index_query_01.py new file mode 100644 index 0000000000..09ccb676de --- /dev/null +++ b/jrs/archive/_2_index_query_01.py @@ -0,0 +1,167 @@ +import os +import asyncio +import logging +import logging.config +from lightrag import LightRAG, QueryParam +from lightrag.llm.openai import gpt_4o_mini_complete +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.utils import logger, set_verbose_debug, EmbeddingFunc +from llama_index.embeddings.openai import OpenAIEmbedding + +# Load environment variables from .env file +# from dotenv import load_dotenv +# load_dotenv() + +# Configuration +WORKING_DIR = "/home/js/LightRAG/jrs/work/mod_linx/_mod_linx_work_dir" +EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large") +EMBEDDING_DIM = int(os.getenv("EMBEDDING_DIM", 3072)) +API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY") +MAX_TOKEN_SIZE = int(os.getenv("MAX_TOKEN_SIZE", 8192)) + + +def configure_logging(): + """Configure logging with console and rotating file handlers.""" + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_query.log")) + print(f"\nLightRAG query log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(levelname)s: %(message)s"}, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "true").lower() == "true") + + +async def initialize_rag(): + """Initialize LightRAG with custom embedding function.""" + print("Initializing LightRAG for querying...") + + # Initialize embedding model + embed_model = OpenAIEmbedding( + model=EMBEDDING_MODEL, api_key=API_KEY, dimensions=EMBEDDING_DIM + ) + + # Define async embedding function + async def async_embedding_func(texts): + return embed_model.get_text_embedding_batch(texts) + + # Define embedding function + embedding_func = EmbeddingFunc( + embedding_dim=EMBEDDING_DIM, + max_token_size=MAX_TOKEN_SIZE, + func=async_embedding_func, + ) + + # Initialize LightRAG + rag = LightRAG( + working_dir=WORKING_DIR, + embedding_func=embedding_func, + llm_model_func=gpt_4o_mini_complete, + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + await rag.aclear_cache() + return rag + + +async def main(): + """Main function to query the LightRAG index.""" + rag = None + try: + if not os.getenv("OPENAI_API_KEY") and not API_KEY: + raise ValueError( + "OPENAI_API_KEY or EMBEDDING_BINDING_API_KEY environment variable not set" + ) + rag = await initialize_rag() + + # Check if index exists + if not os.path.exists(os.path.join(WORKING_DIR, "kv_store_full_docs.json")): + raise FileNotFoundError( + f"No index found in {WORKING_DIR}. Run the indexing script first." + ) + + # Perform query + + query = "What does wire 130 do and what is each end of the wire connected to?" + + for mode in [ + "naive", + "local", + "global", + "hybrid", + "mix", + ]: # "naive", "local", "global", "hybrid", "mix" + print("\n=====================") + print(f"Query mode: {mode}") + print("=====================") + response = await rag.aquery( + query, + param=QueryParam( + mode=mode, top_k=70 + ), # top_k=70, only_need_context=True, only_need_prompt=True + ) + print(response) + except Exception as e: + print(f"An error occurred: {e}") + import traceback + + traceback.print_exc() + finally: + if rag: + print("Finalizing storages...") + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) + print("\nQuerying Done!") + + """ + print("--- All Loaded Environment Variables ---") + # os.environ is a dictionary-like object + # You can iterate over its items (key-value pairs) + for key, value in os.environ.items(): + print(f"{key}={value}") + + print("--------------------------------------") + """ diff --git a/jrs/archive/_2_ra_original_index_and_query.py b/jrs/archive/_2_ra_original_index_and_query.py new file mode 100644 index 0000000000..f7003621cc --- /dev/null +++ b/jrs/archive/_2_ra_original_index_and_query.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python +""" +Example script demonstrating the integration of MinerU parser with RAGAnything + +This example shows how to: +1. Process parsed documents with RAGAnything +2. Perform multimodal queries on the processed documents +3. Handle different types of content (text, images, tables) +""" + +import os +import argparse +import asyncio +import logging +import logging.config +from pathlib import Path + +# Add project root directory to Python path +import sys + +sys.path.append(str(Path(__file__).parent.parent)) + +# Ensures the script can find local LightRAG modules if they aren't in site-packages +current_dir = Path(__file__).resolve().parent +sys.path.append(str(current_dir.parent)) + +from lightrag.llm.openai import openai_complete_if_cache, openai_embed # noqa: E402 +from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug # noqa: E402 +from raganything import RAGAnything, RAGAnythingConfig # noqa: E402 + + +def configure_logging(): + """Configure logging for the application""" + # Get log directory path from environment variable or use current directory + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, "raganything_example.log")) + + print(f"\nRAGAnything example log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups + + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", + }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + + # Set the logger level to INFO + logger.setLevel(logging.INFO) + # Enable verbose debug if needed + set_verbose_debug(os.getenv("VERBOSE", "false").lower() == "true") + + +async def process_with_rag( + file_path: str, + output_dir: str, + api_key: str, + base_url: str = None, + working_dir: str = None, +): + """ + Process document with RAGAnything + + Args: + file_path: Path to the document + output_dir: Output directory for RAG results + api_key: OpenAI API key + base_url: Optional base URL for API + working_dir: Working directory for RAG storage + """ + try: + # Create RAGAnything configuration + config = RAGAnythingConfig( + working_dir=working_dir or "./rag_storage", + enable_image_processing=True, + enable_table_processing=True, + enable_equation_processing=True, + ) + + # Define LLM model function + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + # Define vision model function for image processing + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + ( + {"role": "system", "content": system_prompt} + if system_prompt + else None + ), + ( + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + } + if image_data + else {"role": "user", "content": prompt} + ), + ], + api_key=api_key, + base_url=base_url, + **kwargs, + ) + else: + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + # Initialize RAGAnything with new dataclass structure + rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + + # Process document + await rag.process_document_complete( + file_path=file_path, output_dir=output_dir, parse_method="auto" + ) + + # Example queries - demonstrating different query approaches + logger.info("\nQuerying processed document:") + + # 1. Pure text queries using aquery() + text_queries = [ + "What is the main content of the document?", + "What are the key topics discussed?", + ] + + for query in text_queries: + logger.info(f"\n[Text Query]: {query}") + result = await rag.aquery(query, mode="hybrid") + logger.info(f"Answer: {result}") + + # logger.info( + # "\n[Multimodal Query]: Analyzing fabric attenuation chart data" + # ) + # multimodal_result = await rag.aquery_with_multimodal( + # "What fabric offers the least attenuation of near infrared light?", + # multimodal_content=[ + # { + # "type": "image", + # "image_data": "base64_encoded_image_data_here", # Base64 encoded image of Figure 7 + # "image_caption": "Figure 7: Point measurements showing drop in NIR intensity across multiple fabric layers", + # } + # ], + # mode="hybrid", + # ) + # logger.info(f"Answer: {multimodal_result}") + + # 2. Multimodal query with specific multimodal content using aquery_with_multimodal() + # logger.info( + # "\n[Multimodal Query]: Analyzing performance data in context of document" + # ) + # multimodal_result = await rag.aquery_with_multimodal( + # "Compare this performance data with any similar results mentioned in the document", + # multimodal_content=[ + # { + # "type": "table", + # "table_data": """Method,Accuracy,Processing_Time + # RAGAnything,95.2%,120ms + # Traditional_RAG,87.3%,180ms + # Baseline,82.1%,200ms""", + # "table_caption": "Performance comparison results", + # } + # ], + # mode="hybrid", + # ) + # logger.info(f"Answer: {multimodal_result}") + + # 3. Another multimodal query with equation content + # logger.info("\n[Multimodal Query]: Mathematical formula analysis") + # equation_result = await rag.aquery_with_multimodal( + # "Explain this formula and relate it to any mathematical concepts in the document", + # multimodal_content=[ + # { + # "type": "equation", + # "latex": "F1 = 2 \\cdot \\frac{precision \\cdot recall}{precision + recall}", + # "equation_caption": "F1-score calculation formula", + # } + # ], + # mode="hybrid", + # ) + # logger.info(f"Answer: {equation_result}") + + # Finalize RAGAnything storages. + if hasattr(rag, "close"): + # Try to call it normally; if it's a coroutine, it will be handled + result = rag.close() + if asyncio.iscoroutine(result): + await result + + logger.info("RAG processing and cleanup completed successfully.") + + except Exception as e: + logger.error(f"Error processing with RAG: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + + +def main(): + """Main function to run the example with flexible arguments""" + parser = argparse.ArgumentParser(description="MinerU RAG Example") + + # 1. Input Argument (The file to be indexed) + parser.add_argument( + "--file_path", + "-f", + default="/home/js/LightRAG_BACKUP/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_docs/nir_through_fabrics.pdf", + help="Path to the document to process", + ) + + # 2. Working Directory Arguments + parser.add_argument( + "--working_dir", + "-w", + default=os.getenv( + "RAG_WORKING_DIR", + "/home/js/LightRAG_BACKUP/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_output_dir", + ), + help="Working directory path", + ) + + # 3. Output Files Directory Arguments + parser.add_argument( + "--output", + "-o", + default=os.getenv( + "RAG_OUTPUT_DIR", + "/home/js/LightRAG_BACKUP/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_output_dir", + ), + help="Output directory path", + ) + + # 4. API Key Argument + parser.add_argument( + "--api-key", + default=os.getenv("OPENAI_API_KEY"), + help="OpenAI API key (defaults to OPENAI_API_KEY env var)", + ) + + # 5. Base URL Argument (Optional, for proxy or local LLM endpoints) + parser.add_argument( + "--base-url", + "-b", + default=os.getenv("OPENAI_BASE_URL"), + help="Optional base URL for API (e.g., https://api.openai.com/v1)", + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Print configuration and exit without processing", + ) + + args = parser.parse_args() + + # Check if the input file exists + # We use .resolve() to handle relative paths like './my_input_file.pdf' + input_path = Path(args.file_path).resolve() + file_exists = input_path.exists() + + if args.dry_run: + print("\n=== DRY RUN MODE ===") + print(f"File to process: {args.file_path}") + print(f"Output Dir: {args.output}") + print(f"Working Dir: {args.working_dir}") + print(f"API Key: {'LOADED' if args.api_key else 'MISSING'}") + print(f"Base URL: {args.base_url}") + print("====================\n") + print("Configuration looks good. Remove --dry-run to start processing.") + file_exists = os.path.exists(args.file_path) + print( + f"Input File Exists?: {'YES, The input file exists.' if file_exists else 'NO (Check your file path!)'}" + ) + return + + # 3. Guard Clause: Stop the script if not a dry run and file is missing + if not file_exists: + print(f"\nFATAL ERROR: The file '{args.file_path}' does not exist.") + print(f"Resolved path: {input_path}") + print("Please provide a valid path using -f or --file_path.\n") + sys.exit(1) # Terminates the program immediately + + # Create directories (only if not a dry run) + os.makedirs(args.output, exist_ok=True) + os.makedirs(args.working_dir, exist_ok=True) + + # Priority Logic for file_path: Command line > Env Var > Error + file_to_process = args.file_path or os.getenv("RAG_FILE_PATH") + + if not file_to_process: + logger.error( + "Error: No file path provided via argument or RAG_FILE_PATH env var." + ) + return + + if not args.api_key: + logger.error("Error: OpenAI API key is required.") + return + + # Run the RAG process + asyncio.run( + process_with_rag( + args.file_path, args.output, args.api_key, args.base_url, args.working_dir + ) + ) + + +if __name__ == "__main__": + # Configure logging first + configure_logging() + + print("RAGAnything Example") + print("=" * 30) + print("Processing document with multimodal RAG pipeline") + print("=" * 30) + + main() diff --git a/jrs/archive/_2_ra_query_image_old01.py b/jrs/archive/_2_ra_query_image_old01.py new file mode 100644 index 0000000000..b8b766ad35 --- /dev/null +++ b/jrs/archive/_2_ra_query_image_old01.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python + +# For usage instructions enter the following command: +# python3 ~/LightRAG/jrs/_2_ra_query_image.py --help + +import os +import argparse +import asyncio + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_image_query( + query_text, api_key, base_url, working_dir, modes, output_file +): + try: + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + ) + + # 1. Text LLM + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + # 2. Vision LLM (Crucial for Image Queries) + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + ( + {"role": "system", "content": system_prompt} + if system_prompt + else None + ), + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + }, + ], + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + + print("INFO: Initializing Multimodal Engine...") + await rag._ensure_lightrag_initialized() + + # Prepare Markdown File + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"\n# Query: {query_text}\n") + f.write(f"**Working Dir:** `{working_dir}`\n\n") + + # --- MULTI-MODE QUERY LOOP --- + for current_mode in modes: + print(f"\n>>> Executing [ {current_mode.upper()} ] mode...") + + try: + # query_with_multimodal is the method for reasoning over indexed visuals + result = await rag.aquery_with_multimodal(query_text, mode=current_mode) + + # Output to Console + print(f"\n[ {current_mode.upper()} ANSWER ]:") + print(f"{result}") + + # Output to Markdown File + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"## Mode: {current_mode.upper()}\n") + f.write(f"{result}\n\n") + f.write("---\n") + + except Exception as e: + error_msg = f"Error in {current_mode} mode: {e}" + print(error_msg) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"### Mode: {current_mode.upper()} (FAILED)\n") + f.write(f"Error: {error_msg}\n\n") + + # Cleanup attempts (preserving existing logic) + if hasattr(rag, "finalize_storages"): + res = rag.finalize_storages() + if asyncio.iscoroutine(res): + await res + + if hasattr(rag, "lightrag") and rag.lightrag: + if hasattr(rag.lightrag, "storage") and hasattr( + rag.lightrag.storage, "close" + ): + await rag.lightrag.storage.close() + + del rag + + except Exception as e: + print(f"Query Error: {e}") + + +def main(): + parser = argparse.ArgumentParser(description="Multimodal Image Query Script") + parser.add_argument("query", help="Your question about the images/charts") + + # Modes parameter: Split by comma to allow multiple (e.g., -m naive,hybrid) + parser.add_argument( + "--modes", + "-m", + default="hybrid", + help="Comma-separated list of modes: naive,local,global,hybrid,mix", + ) + + # File parameter: Defaulting to LightRAG directory + parser.add_argument( + "--file", + "-f", + default="/home/js/LightRAG/mm_query_output.md", + help="Path to the output markdown file", + ) + + parser.add_argument( + "--working_dir", + "-w", + default="/home/js/LightRAG/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_work_dir", + help="Path to directory where index of knowledge is stored", + ) + + args = parser.parse_args() + + # Convert the comma-separated string into a clean Python list + mode_list = [m.strip().lower() for m in args.modes.split(",")] + + asyncio.run( + run_image_query( + args.query, + os.getenv("OPENAI_API_KEY"), + os.getenv("OPENAI_BASE_URL"), + args.working_dir, + mode_list, + args.file, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/archive/_2_ra_query_image_old02.py b/jrs/archive/_2_ra_query_image_old02.py new file mode 100644 index 0000000000..4b3003f928 --- /dev/null +++ b/jrs/archive/_2_ra_query_image_old02.py @@ -0,0 +1,257 @@ +#!/usr/bin/env python + +# For usage instructions enter the following command: +# python3 ~/LightRAG/jrs/_2_ra_query_image.py --help + +import os +import argparse +import asyncio +from datetime import datetime + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_image_query( + query_text, api_key, base_url, working_dir, modes, output_file, query_params +): + try: + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + ) + + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + ( + {"role": "system", "content": system_prompt} + if system_prompt + else None + ), + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + }, + ], + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + + print("INFO: Initializing Multimodal Engine...") + await rag._ensure_lightrag_initialized() + + # Timestamp for the log + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # Start the Markdown Log entry + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"\n# Query Session: {timestamp}\n") + f.write(f"**Query:** `{query_text}`\n") + f.write(f"**Working Directory:** `{working_dir}`\n\n") + + # Create a Markdown Table for Parameters + f.write("### Session Parameters\n") + f.write("| Parameter | Value |\n") + f.write("| :--- | :--- |\n") + for param, value in query_params.items(): + f.write(f"| {param} | {value} |\n") + f.write(f"| modes_tested | {', '.join(modes)} |\n\n") + + # --- MULTI-MODE QUERY LOOP --- + for current_mode in modes: + print(f"\n>>> Executing [ {current_mode.upper()} ] mode...") + + # Local update for the current mode + query_params["mode"] = current_mode + + try: + result = await rag.aquery_with_multimodal(query_text, **query_params) + + # Output to Console + print(f"\n[ {current_mode.upper()} ANSWER ]:") + print(f"{result}") + + # Output to Markdown File + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"## Analysis Mode: {current_mode.upper()}\n") + f.write(f"{result}\n\n") + f.write("---\n") + + except Exception as e: + error_msg = f"Error in {current_mode} mode: {e}" + print(error_msg) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"### Analysis Mode: {current_mode.upper()} (FAILED)\n") + f.write(f"**Error:** {error_msg}\n\n") + + # --- CLEANUP --- + if hasattr(rag, "finalize_storages"): + res = rag.finalize_storages() + if asyncio.iscoroutine(res): + await res + + if hasattr(rag, "lightrag") and rag.lightrag: + if hasattr(rag.lightrag, "storage") and hasattr( + rag.lightrag.storage, "close" + ): + await rag.lightrag.storage.close() + + del rag + + except Exception as e: + print(f"Query Error: {e}") + + +def main(): + description = """ +LightRAG Multimodal Query Tool +------------------------------ +This script performs advanced retrieval-augmented generation on indexed documents, +focusing on multimodal content. It logs all results and parameters to a Markdown file. +""" + epilog = ( + "For more detailed documentation, visit: file:///home/js/LightRAG/jrs/_notes" + ) + + parser = argparse.ArgumentParser( + description=description, + epilog=epilog, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("query", help="Your question about the images/charts") + parser.add_argument( + "--modes", + "-m", + default="hybrid", + help="Comma-separated list of modes: naive,local,global,hybrid,mix,bypass", + ) + parser.add_argument( + "--output_file", + "-o", + default="/home/js/LightRAG/mm_query_output.md", + help="Output MD file where LLM response is found", + ) + parser.add_argument( + "--working_dir", + "-w", + default="/home/js/LightRAG/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_work_dir", + help="Index path", + ) + + # LightRAG Parameters + parser.add_argument( + "--response_type", default="Multiple Paragraphs", help="Response format" + ) + parser.add_argument("--top_k", type=int, default=60, help="Top items to retrieve") + parser.add_argument( + "--chunk_top_k", type=int, default=20, help="Initial text chunks" + ) + parser.add_argument( + "--max_entity_tokens", type=int, default=6000, help="Max entity tokens" + ) + parser.add_argument( + "--max_relation_tokens", type=int, default=8000, help="Max relation tokens" + ) + parser.add_argument( + "--max_total_tokens", type=int, default=30000, help="Total token budget" + ) + + # Flags + parser.add_argument( + "--only_context", action="store_true", help="Only return context" + ) + parser.add_argument("--only_prompt", action="store_true", help="Only return prompt") + parser.add_argument("--stream", action="store_true", help="Enable streaming") + parser.add_argument( + "--disable_rerank", + action="store_false", + dest="enable_rerank", + help="Disable reranking", + ) + + parser.add_argument("--user_prompt", help="Custom instructions for LLM") + + args = parser.parse_args() + + mode_list = [m.strip().lower() for m in args.modes.split(",")] + + # Dictionary to be unpacked as **kwargs + query_params = { + "only_need_context": args.only_context, + "only_need_prompt": args.only_prompt, + "response_type": args.response_type, + "stream": args.stream, + "top_k": args.top_k, + "chunk_top_k": args.chunk_top_k, + "max_entity_tokens": args.max_entity_tokens, + "max_relation_tokens": args.max_relation_tokens, + "max_total_tokens": args.max_total_tokens, + "user_prompt": args.user_prompt, + "enable_rerank": args.enable_rerank, + } + + asyncio.run( + run_image_query( + args.query, + os.getenv("OPENAI_API_KEY"), + os.getenv("OPENAI_BASE_URL"), + args.working_dir, + mode_list, + args.output_file, + query_params, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/archive/_2_ra_query_image_old03.py b/jrs/archive/_2_ra_query_image_old03.py new file mode 100644 index 0000000000..1947c6561d --- /dev/null +++ b/jrs/archive/_2_ra_query_image_old03.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python + +# For usage instructions enter the following command: +# python3 ~/LightRAG/jrs/_2_ra_query_image.py --help + +import os +import argparse +import asyncio +from datetime import datetime + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_image_query( + query_text, api_key, base_url, working_dir, modes, output_file, query_params +): + try: + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + ) + + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + def vision_model_func( + prompt, system_prompt=None, history_messages=[], image_data=None, **kwargs + ): + if image_data: + return openai_complete_if_cache( + "gpt-4o", + "", + system_prompt=None, + history_messages=[], + messages=[ + ( + {"role": "system", "content": system_prompt} + if system_prompt + else None + ), + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_data}" + }, + }, + ], + }, + ], + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return llm_model_func(prompt, system_prompt, history_messages, **kwargs) + + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + rag = RAGAnything( + config=config, + llm_model_func=llm_model_func, + vision_model_func=vision_model_func, + embedding_func=embedding_func, + ) + + print("INFO: Initializing Multimodal Engine...") + await rag._ensure_lightrag_initialized() + + # Timestamp for the log + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # Start the Markdown Log entry + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"\n# Query Session: {timestamp}\n") + f.write(f"**Query:** `{query_text}`\n") + f.write(f"**Working Directory:** `{working_dir}`\n\n") + + # Create a Markdown Table for Parameters + f.write("### Session Parameters\n") + f.write("| Parameter | Value |\n") + f.write("| :--- | :--- |\n") + for param, value in query_params.items(): + f.write(f"| {param} | {value} |\n") + f.write(f"| modes_tested | {', '.join(modes)} |\n\n") + + # --- MULTI-MODE QUERY LOOP --- + for current_mode in modes: + print(f"\n>>> Executing [ {current_mode.upper()} ] mode...") + + try: + # 1. Create a clean copy of your query parameters + run_params = query_params.copy() + + # 2. Remove 'mode' from the dictionary if it's there + # This prevents the "multiple values for keyword argument 'mode'" error + run_params.pop("mode", None) + + # 3. Call the method by unpacking the remaining parameters + # We pass the mode explicitly and then everything else via **run_params + result = await rag.aquery_with_multimodal( + query_text, mode=current_mode, **run_params + ) + + # Output to Console + print(f"\n[ {current_mode.upper()} ANSWER ]:") + print(f"{result}") + + # Output to Markdown File + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"## Analysis Mode: {current_mode.upper()}\n") + f.write(f"{result}\n\n") + f.write("---\n") + + except Exception as e: + error_msg = f"Error in {current_mode} mode: {e}" + print(error_msg) + with open(output_file, "a", encoding="utf-8") as f: + f.write(f"### Analysis Mode: {current_mode.upper()} (FAILED)\n") + f.write(f"**Error:** {error_msg}\n\n") + + # --- CLEANUP --- + if hasattr(rag, "finalize_storages"): + res = rag.finalize_storages() + if asyncio.iscoroutine(res): + await res + + if hasattr(rag, "lightrag") and rag.lightrag: + if hasattr(rag.lightrag, "storage") and hasattr( + rag.lightrag.storage, "close" + ): + await rag.lightrag.storage.close() + + del rag + + except Exception as e: + print(f"Query Error: {e}") + + +def main(): + description = """ +LightRAG Multimodal Query Tool +------------------------------ +This script performs advanced retrieval-augmented generation on indexed documents, +focusing on multimodal content. It logs all results and parameters to a Markdown file. +""" + epilog = ( + "For more detailed documentation, visit: file:///home/js/LightRAG/jrs/_notes" + ) + + parser = argparse.ArgumentParser( + description=description, + epilog=epilog, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("query", help="Your question about the images/charts") + parser.add_argument( + "--modes", + "-m", + default="hybrid", + help="Comma-separated list of modes: naive,local,global,hybrid,mix,bypass", + ) + parser.add_argument( + "--output_file", + "-o", + default="/home/js/LightRAG/mm_query_output.md", + help="Output MD file where LLM response is found", + ) + parser.add_argument( + "--working_dir", + "-w", + default="/home/js/LightRAG/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_work_dir", + help="Index path", + ) + + # LightRAG Parameters + parser.add_argument( + "--response_type", default="Multiple Paragraphs", help="Response format" + ) + parser.add_argument("--top_k", type=int, default=60, help="Top items to retrieve") + parser.add_argument( + "--chunk_top_k", type=int, default=20, help="Initial text chunks" + ) + parser.add_argument( + "--max_entity_tokens", type=int, default=6000, help="Max entity tokens" + ) + parser.add_argument( + "--max_relation_tokens", type=int, default=8000, help="Max relation tokens" + ) + parser.add_argument( + "--max_total_tokens", type=int, default=30000, help="Total token budget" + ) + + # Flags + parser.add_argument( + "--only_context", action="store_true", help="Only return context" + ) + parser.add_argument("--only_prompt", action="store_true", help="Only return prompt") + parser.add_argument("--stream", action="store_true", help="Enable streaming") + parser.add_argument( + "--disable_rerank", + action="store_false", + dest="enable_rerank", + help="Disable reranking", + ) + + parser.add_argument("--user_prompt", help="Custom instructions for LLM") + + args = parser.parse_args() + + mode_list = [m.strip().lower() for m in args.modes.split(",")] + + # Dictionary to be unpacked as **kwargs + query_params = { + "only_need_context": args.only_context, + "only_need_prompt": args.only_prompt, + "response_type": args.response_type, + "stream": args.stream, + "top_k": args.top_k, + "chunk_top_k": args.chunk_top_k, + "max_entity_tokens": args.max_entity_tokens, + "max_relation_tokens": args.max_relation_tokens, + "max_total_tokens": args.max_total_tokens, + "user_prompt": args.user_prompt, + "enable_rerank": args.enable_rerank, + } + + asyncio.run( + run_image_query( + args.query, + os.getenv("OPENAI_API_KEY"), + os.getenv("OPENAI_BASE_URL"), + args.working_dir, + mode_list, + args.output_file, + query_params, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/archive/_2_ra_query_text_old01.py b/jrs/archive/_2_ra_query_text_old01.py new file mode 100644 index 0000000000..b3b700133c --- /dev/null +++ b/jrs/archive/_2_ra_query_text_old01.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +import os +import argparse +import asyncio +import sys +from pathlib import Path + +from raganything import RAGAnything, RAGAnythingConfig +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.utils import EmbeddingFunc + + +async def run_query(query_text, api_key, base_url, working_dir, mode): + try: + # 1. Setup Config + config = RAGAnythingConfig( + working_dir=working_dir, + enable_image_processing=True, + enable_table_processing=True, + enable_equation_processing=True, + ) + + # 2. Setup LLM Function + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + + # 3. Setup Embedding Function (must match indexing) + embedding_func = EmbeddingFunc( + embedding_dim=3072, + max_token_size=8192, + func=lambda texts: openai_embed( + texts, + model="text-embedding-3-large", + api_key=api_key, + base_url=base_url, + ), + ) + + # 4. Initialize RAGAnything Wrapper + rag = RAGAnything( + config=config, llm_model_func=llm_model_func, embedding_func=embedding_func + ) + + # --- THE FIXES --- + # 1. Use the correct internal initialization method with 'await' + print("INFO: Connecting to existing index...") + await rag._ensure_lightrag_initialized() + + # 2. Verify the 'lightrag' attribute is now populated + if not rag.lightrag: + raise RuntimeError( + f"Failed to load LightRAG from {working_dir}. Check if files exist there." + ) + + print(f"\n--- Querying [{mode} mode]: {query_text} ---") + + # 3. Perform the actual query + result = await rag.aquery(query_text, mode=mode) + print(f"\nANSWER:\n{result}\n") + + # 4. Safely finalize storages (handling both sync and async possibilities) + if hasattr(rag, "finalize_storages"): + res = rag.finalize_storages() + if asyncio.iscoroutine(res): + await res + print("INFO: Storage finalized successfully.") + + except Exception as e: + print(f"Query Error: {e}") + # print(traceback.format_exc()) # Uncomment if you need deeper debugging + + +def main(): + parser = argparse.ArgumentParser(description="RAG Query Script") + parser.add_argument("query", help="The question you want to ask") + parser.add_argument( + "--working_dir", + "-w", + default="/home/js/LightRAG/jrs/work/seheult/_ra/nir_through_fabrics/_ra_seheult_work_dir", + ) + parser.add_argument( + "--mode", + "-m", + default="hybrid", + choices=["naive", "local", "global", "hybrid", "mix"], + ) + parser.add_argument("--api-key", default=os.getenv("OPENAI_API_KEY")) + parser.add_argument("--base-url", "-b", default=os.getenv("OPENAI_BASE_URL")) + + args = parser.parse_args() + + # Verify directory exists before starting + if not Path(args.working_dir).exists(): + print(f"FATAL: Working directory {args.working_dir} not found.") + sys.exit(1) + + # Execute the async run + asyncio.run( + run_query(args.query, args.api_key, args.base_url, args.working_dir, args.mode) + ) + + +if __name__ == "__main__": + main() diff --git a/jrs/archive/_3_test_02_inspect_queryparam.py b/jrs/archive/_3_test_02_inspect_queryparam.py new file mode 100644 index 0000000000..3e1f806c52 --- /dev/null +++ b/jrs/archive/_3_test_02_inspect_queryparam.py @@ -0,0 +1,17 @@ +import lightrag +from lightrag import QueryParam +import inspect + +print("_____________") + +for name, param in inspect.signature(QueryParam.__init__).parameters.items(): + print(f"{name}: default={param.default}, kind={param.kind}") + +print("_____________") + + +print("LightRAG is loaded from:", lightrag.__file__) + +print("_____________") + +print(QueryParam.__annotations__) diff --git a/jrs/archive/_3_test_03_reformat_sources.py b/jrs/archive/_3_test_03_reformat_sources.py new file mode 100644 index 0000000000..e627297c01 --- /dev/null +++ b/jrs/archive/_3_test_03_reformat_sources.py @@ -0,0 +1,19 @@ +# Define your input text here +raw_text = r""" +[That's what metastasis effectively is.](https://www.youtube.com/watch?v=Ln3WszTq0uA&t=5782s) So we have everything completely +[5782.62 > 5789.82] (Jack) backward. And if you think about chemotherapeutic drugs, every single drug is targeted at mitosis +[5789.82 > 5795.66] (Jack) at some level. So if you really understand what I'm saying, we're doing everything wrong. +""" + +# Clean the lines +cleaned_lines = [] +for line in raw_text.strip().splitlines(): + cleaned_line = line.lstrip("\\") + " " + cleaned_lines.append(cleaned_line) + +# Output to a text file +output_filename = "_formatted_output.txt" +with open(output_filename, "w", encoding="utf-8") as f: + f.write("\n".join(cleaned_lines)) + +print(f"Formatted text written to '{output_filename}'") diff --git a/jrs/archive/_3_test_convert_to_lower_01.py b/jrs/archive/_3_test_convert_to_lower_01.py new file mode 100644 index 0000000000..34cc0b3bca --- /dev/null +++ b/jrs/archive/_3_test_convert_to_lower_01.py @@ -0,0 +1,14 @@ +# Convert all uppercase letters to lowercase in testin.txt and save to testout.txt + +# Read the content of testin.txt +with open("testin.txt", "r", encoding="utf-8") as infile: + content = infile.read() + +# Convert to lowercase +lowercase_content = content.lower() + +# Write the lowercase content to testout.txt +with open("testout.txt", "w", encoding="utf-8") as outfile: + outfile.write(lowercase_content) + +print("Conversion complete. Output written to testout.txt.") diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index e892ff011c..8bd47625cb 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -1,3 +1,5 @@ +# graph_routes.py + """ This module contains all graph-related routes for the LightRAG API. """ @@ -6,6 +8,9 @@ import traceback from fastapi import APIRouter, Depends, Query, HTTPException from pydantic import BaseModel, Field +from lightrag.kg.shared_storage import ( + set_all_update_flags_for_all_namespaces, +) # Import the new function from lightrag.utils import logger from ..utils_api import get_combined_auth_dependency @@ -685,4 +690,24 @@ async def merge_entities(request: EntityMergeRequest): status_code=500, detail=f"Error merging entities: {str(e)}" ) + @router.post("/graph/refresh-data", dependencies=[Depends(combined_auth)]) + async def refresh_graph_data_from_disk(): + """ + Triggers a refresh of all graph data from disk. + This signals all worker processes to reload their knowledge graph data. + """ + try: + await set_all_update_flags_for_all_namespaces() + logger.info("Triggered refresh of all graph data from disk.") + return { + "status": "success", + "message": "Graph data refresh triggered successfully.", + } + except Exception as e: + logger.error(f"Error triggering graph data refresh: {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error triggering graph data refresh: {str(e)}" + ) + return router diff --git a/lightrag/api/routers/graph_routes_original.py b/lightrag/api/routers/graph_routes_original.py new file mode 100644 index 0000000000..e892ff011c --- /dev/null +++ b/lightrag/api/routers/graph_routes_original.py @@ -0,0 +1,688 @@ +""" +This module contains all graph-related routes for the LightRAG API. +""" + +from typing import Optional, Dict, Any +import traceback +from fastapi import APIRouter, Depends, Query, HTTPException +from pydantic import BaseModel, Field + +from lightrag.utils import logger +from ..utils_api import get_combined_auth_dependency + +router = APIRouter(tags=["graph"]) + + +class EntityUpdateRequest(BaseModel): + entity_name: str + updated_data: Dict[str, Any] + allow_rename: bool = False + allow_merge: bool = False + + +class RelationUpdateRequest(BaseModel): + source_id: str + target_id: str + updated_data: Dict[str, Any] + + +class EntityMergeRequest(BaseModel): + entities_to_change: list[str] = Field( + ..., + description="List of entity names to be merged and deleted. These are typically duplicate or misspelled entities.", + min_length=1, + examples=[["Elon Msk", "Ellon Musk"]], + ) + entity_to_change_into: str = Field( + ..., + description="Target entity name that will receive all relationships from the source entities. This entity will be preserved.", + min_length=1, + examples=["Elon Musk"], + ) + + +class EntityCreateRequest(BaseModel): + entity_name: str = Field( + ..., + description="Unique name for the new entity", + min_length=1, + examples=["Tesla"], + ) + entity_data: Dict[str, Any] = Field( + ..., + description="Dictionary containing entity properties. Common fields include 'description' and 'entity_type'.", + examples=[ + { + "description": "Electric vehicle manufacturer", + "entity_type": "ORGANIZATION", + } + ], + ) + + +class RelationCreateRequest(BaseModel): + source_entity: str = Field( + ..., + description="Name of the source entity. This entity must already exist in the knowledge graph.", + min_length=1, + examples=["Elon Musk"], + ) + target_entity: str = Field( + ..., + description="Name of the target entity. This entity must already exist in the knowledge graph.", + min_length=1, + examples=["Tesla"], + ) + relation_data: Dict[str, Any] = Field( + ..., + description="Dictionary containing relationship properties. Common fields include 'description', 'keywords', and 'weight'.", + examples=[ + { + "description": "Elon Musk is the CEO of Tesla", + "keywords": "CEO, founder", + "weight": 1.0, + } + ], + ) + + +def create_graph_routes(rag, api_key: Optional[str] = None): + combined_auth = get_combined_auth_dependency(api_key) + + @router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) + async def get_graph_labels(): + """ + Get all graph labels + + Returns: + List[str]: List of graph labels + """ + try: + return await rag.get_graph_labels() + except Exception as e: + logger.error(f"Error getting graph labels: {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error getting graph labels: {str(e)}" + ) + + @router.get("/graph/label/popular", dependencies=[Depends(combined_auth)]) + async def get_popular_labels( + limit: int = Query( + 300, description="Maximum number of popular labels to return", ge=1, le=1000 + ), + ): + """ + Get popular labels by node degree (most connected entities) + + Args: + limit (int): Maximum number of labels to return (default: 300, max: 1000) + + Returns: + List[str]: List of popular labels sorted by degree (highest first) + """ + try: + return await rag.chunk_entity_relation_graph.get_popular_labels(limit) + except Exception as e: + logger.error(f"Error getting popular labels: {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error getting popular labels: {str(e)}" + ) + + @router.get("/graph/label/search", dependencies=[Depends(combined_auth)]) + async def search_labels( + q: str = Query(..., description="Search query string"), + limit: int = Query( + 50, description="Maximum number of search results to return", ge=1, le=100 + ), + ): + """ + Search labels with fuzzy matching + + Args: + q (str): Search query string + limit (int): Maximum number of results to return (default: 50, max: 100) + + Returns: + List[str]: List of matching labels sorted by relevance + """ + try: + return await rag.chunk_entity_relation_graph.search_labels(q, limit) + except Exception as e: + logger.error(f"Error searching labels with query '{q}': {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error searching labels: {str(e)}" + ) + + @router.get("/graphs", dependencies=[Depends(combined_auth)]) + async def get_knowledge_graph( + label: str = Query(..., description="Label to get knowledge graph for"), + max_depth: int = Query(3, description="Maximum depth of graph", ge=1), + max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), + ): + """ + Retrieve a connected subgraph of nodes where the label includes the specified label. + When reducing the number of nodes, the prioritization criteria are as follows: + 1. Hops(path) to the staring node take precedence + 2. Followed by the degree of the nodes + + Args: + label (str): Label of the starting node + max_depth (int, optional): Maximum depth of the subgraph,Defaults to 3 + max_nodes: Maxiumu nodes to return + + Returns: + Dict[str, List[str]]: Knowledge graph for label + """ + try: + # Log the label parameter to check for leading spaces + logger.debug( + f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})" + ) + + return await rag.get_knowledge_graph( + node_label=label, + max_depth=max_depth, + max_nodes=max_nodes, + ) + except Exception as e: + logger.error(f"Error getting knowledge graph for label '{label}': {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error getting knowledge graph: {str(e)}" + ) + + @router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)]) + async def check_entity_exists( + name: str = Query(..., description="Entity name to check"), + ): + """ + Check if an entity with the given name exists in the knowledge graph + + Args: + name (str): Name of the entity to check + + Returns: + Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists + """ + try: + exists = await rag.chunk_entity_relation_graph.has_node(name) + return {"exists": exists} + except Exception as e: + logger.error(f"Error checking entity existence for '{name}': {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error checking entity existence: {str(e)}" + ) + + @router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)]) + async def update_entity(request: EntityUpdateRequest): + """ + Update an entity's properties in the knowledge graph + + This endpoint allows updating entity properties, including renaming entities. + When renaming to an existing entity name, the behavior depends on allow_merge: + + Args: + request (EntityUpdateRequest): Request containing: + - entity_name (str): Name of the entity to update + - updated_data (Dict[str, Any]): Dictionary of properties to update + - allow_rename (bool): Whether to allow entity renaming (default: False) + - allow_merge (bool): Whether to merge into existing entity when renaming + causes name conflict (default: False) + + Returns: + Dict with the following structure: + { + "status": "success", + "message": "Entity updated successfully" | "Entity merged successfully into 'target_name'", + "data": { + "entity_name": str, # Final entity name + "description": str, # Entity description + "entity_type": str, # Entity type + "source_id": str, # Source chunk IDs + ... # Other entity properties + }, + "operation_summary": { + "merged": bool, # Whether entity was merged into another + "merge_status": str, # "success" | "failed" | "not_attempted" + "merge_error": str | None, # Error message if merge failed + "operation_status": str, # "success" | "partial_success" | "failure" + "target_entity": str | None, # Target entity name if renaming/merging + "final_entity": str, # Final entity name after operation + "renamed": bool # Whether entity was renamed + } + } + + operation_status values explained: + - "success": All operations completed successfully + * For simple updates: entity properties updated + * For renames: entity renamed successfully + * For merges: non-name updates applied AND merge completed + + - "partial_success": Update succeeded but merge failed + * Non-name property updates were applied successfully + * Merge operation failed (entity not merged) + * Original entity still exists with updated properties + * Use merge_error for failure details + + - "failure": Operation failed completely + * If merge_status == "failed": Merge attempted but both update and merge failed + * If merge_status == "not_attempted": Regular update failed + * No changes were applied to the entity + + merge_status values explained: + - "success": Entity successfully merged into target entity + - "failed": Merge operation was attempted but failed + - "not_attempted": No merge was attempted (normal update/rename) + + Behavior when renaming to an existing entity: + - If allow_merge=False: Raises ValueError with 400 status (default behavior) + - If allow_merge=True: Automatically merges the source entity into the existing target entity, + preserving all relationships and applying non-name updates first + + Example Request (simple update): + POST /graph/entity/edit + { + "entity_name": "Tesla", + "updated_data": {"description": "Updated description"}, + "allow_rename": false, + "allow_merge": false + } + + Example Response (simple update success): + { + "status": "success", + "message": "Entity updated successfully", + "data": { ... }, + "operation_summary": { + "merged": false, + "merge_status": "not_attempted", + "merge_error": null, + "operation_status": "success", + "target_entity": null, + "final_entity": "Tesla", + "renamed": false + } + } + + Example Request (rename with auto-merge): + POST /graph/entity/edit + { + "entity_name": "Elon Msk", + "updated_data": { + "entity_name": "Elon Musk", + "description": "Corrected description" + }, + "allow_rename": true, + "allow_merge": true + } + + Example Response (merge success): + { + "status": "success", + "message": "Entity merged successfully into 'Elon Musk'", + "data": { ... }, + "operation_summary": { + "merged": true, + "merge_status": "success", + "merge_error": null, + "operation_status": "success", + "target_entity": "Elon Musk", + "final_entity": "Elon Musk", + "renamed": true + } + } + + Example Response (partial success - update succeeded but merge failed): + { + "status": "success", + "message": "Entity updated successfully", + "data": { ... }, # Data reflects updated "Elon Msk" entity + "operation_summary": { + "merged": false, + "merge_status": "failed", + "merge_error": "Target entity locked by another operation", + "operation_status": "partial_success", + "target_entity": "Elon Musk", + "final_entity": "Elon Msk", # Original entity still exists + "renamed": true + } + } + """ + try: + result = await rag.aedit_entity( + entity_name=request.entity_name, + updated_data=request.updated_data, + allow_rename=request.allow_rename, + allow_merge=request.allow_merge, + ) + + # Extract operation_summary from result, with fallback for backward compatibility + operation_summary = result.get( + "operation_summary", + { + "merged": False, + "merge_status": "not_attempted", + "merge_error": None, + "operation_status": "success", + "target_entity": None, + "final_entity": request.updated_data.get( + "entity_name", request.entity_name + ), + "renamed": request.updated_data.get( + "entity_name", request.entity_name + ) + != request.entity_name, + }, + ) + + # Separate entity data from operation_summary for clean response + entity_data = dict(result) + entity_data.pop("operation_summary", None) + + # Generate appropriate response message based on merge status + response_message = ( + f"Entity merged successfully into '{operation_summary['final_entity']}'" + if operation_summary.get("merged") + else "Entity updated successfully" + ) + return { + "status": "success", + "message": response_message, + "data": entity_data, + "operation_summary": operation_summary, + } + except ValueError as ve: + logger.error( + f"Validation error updating entity '{request.entity_name}': {str(ve)}" + ) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(f"Error updating entity '{request.entity_name}': {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error updating entity: {str(e)}" + ) + + @router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)]) + async def update_relation(request: RelationUpdateRequest): + """Update a relation's properties in the knowledge graph + + Args: + request (RelationUpdateRequest): Request containing source ID, target ID and updated data + + Returns: + Dict: Updated relation information + """ + try: + result = await rag.aedit_relation( + source_entity=request.source_id, + target_entity=request.target_id, + updated_data=request.updated_data, + ) + return { + "status": "success", + "message": "Relation updated successfully", + "data": result, + } + except ValueError as ve: + logger.error( + f"Validation error updating relation between '{request.source_id}' and '{request.target_id}': {str(ve)}" + ) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error( + f"Error updating relation between '{request.source_id}' and '{request.target_id}': {str(e)}" + ) + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error updating relation: {str(e)}" + ) + + @router.post("/graph/entity/create", dependencies=[Depends(combined_auth)]) + async def create_entity(request: EntityCreateRequest): + """ + Create a new entity in the knowledge graph + + This endpoint creates a new entity node in the knowledge graph with the specified + properties. The system automatically generates vector embeddings for the entity + to enable semantic search and retrieval. + + Request Body: + entity_name (str): Unique name identifier for the entity + entity_data (dict): Entity properties including: + - description (str): Textual description of the entity + - entity_type (str): Category/type of the entity (e.g., PERSON, ORGANIZATION, LOCATION) + - source_id (str): Related chunk_id from which the description originates + - Additional custom properties as needed + + Response Schema: + { + "status": "success", + "message": "Entity 'Tesla' created successfully", + "data": { + "entity_name": "Tesla", + "description": "Electric vehicle manufacturer", + "entity_type": "ORGANIZATION", + "source_id": "chunk-123chunk-456" + ... (other entity properties) + } + } + + HTTP Status Codes: + 200: Entity created successfully + 400: Invalid request (e.g., missing required fields, duplicate entity) + 500: Internal server error + + Example Request: + POST /graph/entity/create + { + "entity_name": "Tesla", + "entity_data": { + "description": "Electric vehicle manufacturer", + "entity_type": "ORGANIZATION" + } + } + """ + try: + # Use the proper acreate_entity method which handles: + # - Graph lock for concurrency + # - Vector embedding creation in entities_vdb + # - Metadata population and defaults + # - Index consistency via _edit_entity_done + result = await rag.acreate_entity( + entity_name=request.entity_name, + entity_data=request.entity_data, + ) + + return { + "status": "success", + "message": f"Entity '{request.entity_name}' created successfully", + "data": result, + } + except ValueError as ve: + logger.error( + f"Validation error creating entity '{request.entity_name}': {str(ve)}" + ) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(f"Error creating entity '{request.entity_name}': {str(e)}") + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error creating entity: {str(e)}" + ) + + @router.post("/graph/relation/create", dependencies=[Depends(combined_auth)]) + async def create_relation(request: RelationCreateRequest): + """ + Create a new relationship between two entities in the knowledge graph + + This endpoint establishes an undirected relationship between two existing entities. + The provided source/target order is accepted for convenience, but the backend + stored edge is undirected and may be returned with the entities swapped. + Both entities must already exist in the knowledge graph. The system automatically + generates vector embeddings for the relationship to enable semantic search and graph traversal. + + Prerequisites: + - Both source_entity and target_entity must exist in the knowledge graph + - Use /graph/entity/create to create entities first if they don't exist + + Request Body: + source_entity (str): Name of the source entity (relationship origin) + target_entity (str): Name of the target entity (relationship destination) + relation_data (dict): Relationship properties including: + - description (str): Textual description of the relationship + - keywords (str): Comma-separated keywords describing the relationship type + - source_id (str): Related chunk_id from which the description originates + - weight (float): Relationship strength/importance (default: 1.0) + - Additional custom properties as needed + + Response Schema: + { + "status": "success", + "message": "Relation created successfully between 'Elon Musk' and 'Tesla'", + "data": { + "src_id": "Elon Musk", + "tgt_id": "Tesla", + "description": "Elon Musk is the CEO of Tesla", + "keywords": "CEO, founder", + "source_id": "chunk-123chunk-456" + "weight": 1.0, + ... (other relationship properties) + } + } + + HTTP Status Codes: + 200: Relationship created successfully + 400: Invalid request (e.g., missing entities, invalid data, duplicate relationship) + 500: Internal server error + + Example Request: + POST /graph/relation/create + { + "source_entity": "Elon Musk", + "target_entity": "Tesla", + "relation_data": { + "description": "Elon Musk is the CEO of Tesla", + "keywords": "CEO, founder", + "weight": 1.0 + } + } + """ + try: + # Use the proper acreate_relation method which handles: + # - Graph lock for concurrency + # - Entity existence validation + # - Duplicate relation checks + # - Vector embedding creation in relationships_vdb + # - Index consistency via _edit_relation_done + result = await rag.acreate_relation( + source_entity=request.source_entity, + target_entity=request.target_entity, + relation_data=request.relation_data, + ) + + return { + "status": "success", + "message": f"Relation created successfully between '{request.source_entity}' and '{request.target_entity}'", + "data": result, + } + except ValueError as ve: + logger.error( + f"Validation error creating relation between '{request.source_entity}' and '{request.target_entity}': {str(ve)}" + ) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error( + f"Error creating relation between '{request.source_entity}' and '{request.target_entity}': {str(e)}" + ) + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error creating relation: {str(e)}" + ) + + @router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)]) + async def merge_entities(request: EntityMergeRequest): + """ + Merge multiple entities into a single entity, preserving all relationships + + This endpoint consolidates duplicate or misspelled entities while preserving the entire + graph structure. It's particularly useful for cleaning up knowledge graphs after document + processing or correcting entity name variations. + + What the Merge Operation Does: + 1. Deletes the specified source entities from the knowledge graph + 2. Transfers all relationships from source entities to the target entity + 3. Intelligently merges duplicate relationships (if multiple sources have the same relationship) + 4. Updates vector embeddings for accurate retrieval and search + 5. Preserves the complete graph structure and connectivity + 6. Maintains relationship properties and metadata + + Use Cases: + - Fixing spelling errors in entity names (e.g., "Elon Msk" -> "Elon Musk") + - Consolidating duplicate entities discovered after document processing + - Merging name variations (e.g., "NY", "New York", "New York City") + - Cleaning up the knowledge graph for better query performance + - Standardizing entity names across the knowledge base + + Request Body: + entities_to_change (list[str]): List of entity names to be merged and deleted + entity_to_change_into (str): Target entity that will receive all relationships + + Response Schema: + { + "status": "success", + "message": "Successfully merged 2 entities into 'Elon Musk'", + "data": { + "merged_entity": "Elon Musk", + "deleted_entities": ["Elon Msk", "Ellon Musk"], + "relationships_transferred": 15, + ... (merge operation details) + } + } + + HTTP Status Codes: + 200: Entities merged successfully + 400: Invalid request (e.g., empty entity list, target entity doesn't exist) + 500: Internal server error + + Example Request: + POST /graph/entities/merge + { + "entities_to_change": ["Elon Msk", "Ellon Musk"], + "entity_to_change_into": "Elon Musk" + } + + Note: + - The target entity (entity_to_change_into) must exist in the knowledge graph + - Source entities will be permanently deleted after the merge + - This operation cannot be undone, so verify entity names before merging + """ + try: + result = await rag.amerge_entities( + source_entities=request.entities_to_change, + target_entity=request.entity_to_change_into, + ) + return { + "status": "success", + "message": f"Successfully merged {len(request.entities_to_change)} entities into '{request.entity_to_change_into}'", + "data": result, + } + except ValueError as ve: + logger.error( + f"Validation error merging entities {request.entities_to_change} into '{request.entity_to_change_into}': {str(ve)}" + ) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error( + f"Error merging entities {request.entities_to_change} into '{request.entity_to_change_into}': {str(e)}" + ) + logger.error(traceback.format_exc()) + raise HTTPException( + status_code=500, detail=f"Error merging entities: {str(e)}" + ) + + return router diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 1a9405f89e..ff9a27677b 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -47,7 +47,7 @@ class QueryRequest(BaseModel): ) chunk_top_k: Optional[int] = Field( - ge=1, + ge=0, default=None, description="Number of text chunks to retrieve initially from vector search and keep after reranking.", ) @@ -80,6 +80,11 @@ class QueryRequest(BaseModel): description="List of low-level keywords to refine retrieval focus. Leave empty to use the LLM to generate the keywords.", ) + only_kg_context: Optional[bool] = Field( + default=False, + description="If True, includes only knowledge graph data (entities and relations) in context, excluding all chunks.", + ) + conversation_history: Optional[List[Dict[str, Any]]] = Field( default=None, description="History messages are only sent to LLM for context, not used for retrieval. Format: [{'role': 'user/assistant', 'content': 'message'}].", diff --git a/lightrag/api/routers/query_routes_original.py b/lightrag/api/routers/query_routes_original.py new file mode 100644 index 0000000000..1a9405f89e --- /dev/null +++ b/lightrag/api/routers/query_routes_original.py @@ -0,0 +1,1159 @@ +""" +This module contains all query-related routes for the LightRAG API. +""" + +import json +from typing import Any, Dict, List, Literal, Optional +from fastapi import APIRouter, Depends, HTTPException +from lightrag.base import QueryParam +from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.utils import logger +from pydantic import BaseModel, Field, field_validator + +router = APIRouter(tags=["query"]) + + +class QueryRequest(BaseModel): + query: str = Field( + min_length=3, + description="The query text", + ) + + mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field( + default="mix", + description="Query mode", + ) + + only_need_context: Optional[bool] = Field( + default=None, + description="If True, only returns the retrieved context without generating a response.", + ) + + only_need_prompt: Optional[bool] = Field( + default=None, + description="If True, only returns the generated prompt without producing a response.", + ) + + response_type: Optional[str] = Field( + min_length=1, + default=None, + description="Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.", + ) + + top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.", + ) + + chunk_top_k: Optional[int] = Field( + ge=1, + default=None, + description="Number of text chunks to retrieve initially from vector search and keep after reranking.", + ) + + max_entity_tokens: Optional[int] = Field( + default=None, + description="Maximum number of tokens allocated for entity context in unified token control system.", + ge=1, + ) + + max_relation_tokens: Optional[int] = Field( + default=None, + description="Maximum number of tokens allocated for relationship context in unified token control system.", + ge=1, + ) + + max_total_tokens: Optional[int] = Field( + default=None, + description="Maximum total tokens budget for the entire query context (entities + relations + chunks + system prompt).", + ge=1, + ) + + hl_keywords: list[str] = Field( + default_factory=list, + description="List of high-level keywords to prioritize in retrieval. Leave empty to use the LLM to generate the keywords.", + ) + + ll_keywords: list[str] = Field( + default_factory=list, + description="List of low-level keywords to refine retrieval focus. Leave empty to use the LLM to generate the keywords.", + ) + + conversation_history: Optional[List[Dict[str, Any]]] = Field( + default=None, + description="History messages are only sent to LLM for context, not used for retrieval. Format: [{'role': 'user/assistant', 'content': 'message'}].", + ) + + user_prompt: Optional[str] = Field( + default=None, + description="User-provided prompt for the query. If provided, this will be used instead of the default value from prompt template.", + ) + + enable_rerank: Optional[bool] = Field( + default=None, + description="Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. Default is True.", + ) + + include_references: Optional[bool] = Field( + default=True, + description="If True, includes reference list in responses. Affects /query and /query/stream endpoints. /query/data always includes references.", + ) + + include_chunk_content: Optional[bool] = Field( + default=False, + description="If True, includes actual chunk text content in references. Only applies when include_references=True. Useful for evaluation and debugging.", + ) + + stream: Optional[bool] = Field( + default=True, + description="If True, enables streaming output for real-time responses. Only affects /query/stream endpoint.", + ) + + @field_validator("query", mode="after") + @classmethod + def query_strip_after(cls, query: str) -> str: + return query.strip() + + @field_validator("conversation_history", mode="after") + @classmethod + def conversation_history_role_check( + cls, conversation_history: List[Dict[str, Any]] | None + ) -> List[Dict[str, Any]] | None: + if conversation_history is None: + return None + for msg in conversation_history: + if "role" not in msg: + raise ValueError("Each message must have a 'role' key.") + if not isinstance(msg["role"], str) or not msg["role"].strip(): + raise ValueError("Each message 'role' must be a non-empty string.") + return conversation_history + + def to_query_params(self, is_stream: bool) -> "QueryParam": + """Converts a QueryRequest instance into a QueryParam instance.""" + # Use Pydantic's `.model_dump(exclude_none=True)` to remove None values automatically + # Exclude API-level parameters that don't belong in QueryParam + request_data = self.model_dump( + exclude_none=True, exclude={"query", "include_chunk_content"} + ) + + # Ensure `mode` and `stream` are set explicitly + param = QueryParam(**request_data) + param.stream = is_stream + return param + + +class ReferenceItem(BaseModel): + """A single reference item in query responses.""" + + reference_id: str = Field(description="Unique reference identifier") + file_path: str = Field(description="Path to the source file") + content: Optional[List[str]] = Field( + default=None, + description="List of chunk contents from this file (only present when include_chunk_content=True)", + ) + + +class QueryResponse(BaseModel): + response: str = Field( + description="The generated response", + ) + references: Optional[List[ReferenceItem]] = Field( + default=None, + description="Reference list (Disabled when include_references=False, /query/data always includes references.)", + ) + + +class QueryDataResponse(BaseModel): + status: str = Field(description="Query execution status") + message: str = Field(description="Status message") + data: Dict[str, Any] = Field( + description="Query result data containing entities, relationships, chunks, and references" + ) + metadata: Dict[str, Any] = Field( + description="Query metadata including mode, keywords, and processing information" + ) + + +class StreamChunkResponse(BaseModel): + """Response model for streaming chunks in NDJSON format""" + + references: Optional[List[Dict[str, str]]] = Field( + default=None, + description="Reference list (only in first chunk when include_references=True)", + ) + response: Optional[str] = Field( + default=None, description="Response content chunk or complete response" + ) + error: Optional[str] = Field( + default=None, description="Error message if processing fails" + ) + + +def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): + combined_auth = get_combined_auth_dependency(api_key) + + @router.post( + "/query", + response_model=QueryResponse, + dependencies=[Depends(combined_auth)], + responses={ + 200: { + "description": "Successful RAG query response", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "response": { + "type": "string", + "description": "The generated response from the RAG system", + }, + "references": { + "type": "array", + "items": { + "type": "object", + "properties": { + "reference_id": {"type": "string"}, + "file_path": {"type": "string"}, + "content": { + "type": "array", + "items": {"type": "string"}, + "description": "List of chunk contents from this file (only included when include_chunk_content=True)", + }, + }, + }, + "description": "Reference list (only included when include_references=True)", + }, + }, + "required": ["response"], + }, + "examples": { + "with_references": { + "summary": "Response with references", + "description": "Example response when include_references=True", + "value": { + "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", + "references": [ + { + "reference_id": "1", + "file_path": "/documents/ai_overview.pdf", + }, + { + "reference_id": "2", + "file_path": "/documents/machine_learning.txt", + }, + ], + }, + }, + "with_chunk_content": { + "summary": "Response with chunk content", + "description": "Example response when include_references=True and include_chunk_content=True. Note: content is an array of chunks from the same file.", + "value": { + "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", + "references": [ + { + "reference_id": "1", + "file_path": "/documents/ai_overview.pdf", + "content": [ + "Artificial Intelligence (AI) represents a transformative field in computer science focused on creating systems that can perform tasks requiring human-like intelligence. These tasks include learning from experience, understanding natural language, recognizing patterns, and making decisions.", + "AI systems can be categorized into narrow AI, which is designed for specific tasks, and general AI, which aims to match human cognitive abilities across a wide range of domains.", + ], + }, + { + "reference_id": "2", + "file_path": "/documents/machine_learning.txt", + "content": [ + "Machine learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed. It focuses on the development of algorithms that can access data and use it to learn for themselves." + ], + }, + ], + }, + }, + "without_references": { + "summary": "Response without references", + "description": "Example response when include_references=False", + "value": { + "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving." + }, + }, + "different_modes": { + "summary": "Different query modes", + "description": "Examples of responses from different query modes", + "value": { + "local_mode": "Focuses on specific entities and their relationships", + "global_mode": "Provides broader context from relationship patterns", + "hybrid_mode": "Combines local and global approaches", + "naive_mode": "Simple vector similarity search", + "mix_mode": "Integrates knowledge graph and vector retrieval", + }, + }, + }, + } + }, + }, + 400: { + "description": "Bad Request - Invalid input parameters", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + "example": { + "detail": "Query text must be at least 3 characters long" + }, + } + }, + }, + 500: { + "description": "Internal Server Error - Query processing failed", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + "example": { + "detail": "Failed to process query: LLM service unavailable" + }, + } + }, + }, + }, + ) + async def query_text(request: QueryRequest): + """ + Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored. + + This endpoint performs Retrieval-Augmented Generation (RAG) queries using various modes + to provide intelligent responses based on your knowledge base. + + **Query Modes:** + - **local**: Focuses on specific entities and their direct relationships + - **global**: Analyzes broader patterns and relationships across the knowledge graph + - **hybrid**: Combines local and global approaches for comprehensive results + - **naive**: Simple vector similarity search without knowledge graph + - **mix**: Integrates knowledge graph retrieval with vector search (recommended) + - **bypass**: Direct LLM query without knowledge retrieval + + conversation_history parameteris sent to LLM only, does not affect retrieval results. + + **Usage Examples:** + + Basic query: + ```json + { + "query": "What is machine learning?", + "mode": "mix" + } + ``` + + Bypass initial LLM call by providing high-level and low-level keywords: + ```json + { + "query": "What is Retrieval-Augmented-Generation?", + "hl_keywords": ["machine learning", "information retrieval", "natural language processing"], + "ll_keywords": ["retrieval augmented generation", "RAG", "knowledge base"], + "mode": "mix" + } + ``` + + Advanced query with references: + ```json + { + "query": "Explain neural networks", + "mode": "hybrid", + "include_references": true, + "response_type": "Multiple Paragraphs", + "top_k": 10 + } + ``` + + Conversation with history: + ```json + { + "query": "Can you give me more details?", + "conversation_history": [ + {"role": "user", "content": "What is AI?"}, + {"role": "assistant", "content": "AI is artificial intelligence..."} + ] + } + ``` + + Args: + request (QueryRequest): The request object containing query parameters: + - **query**: The question or prompt to process (min 3 characters) + - **mode**: Query strategy - "mix" recommended for best results + - **include_references**: Whether to include source citations + - **response_type**: Format preference (e.g., "Multiple Paragraphs") + - **top_k**: Number of top entities/relations to retrieve + - **conversation_history**: Previous dialogue context + - **max_total_tokens**: Token budget for the entire response + + Returns: + QueryResponse: JSON response containing: + - **response**: The generated answer to your query + - **references**: Source citations (if include_references=True) + + Raises: + HTTPException: + - 400: Invalid input parameters (e.g., query too short) + - 500: Internal processing error (e.g., LLM service unavailable) + """ + try: + param = request.to_query_params( + False + ) # Ensure stream=False for non-streaming endpoint + # Force stream=False for /query endpoint regardless of include_references setting + param.stream = False + + # Unified approach: always use aquery_llm for both cases + result = await rag.aquery_llm(request.query, param=param) + + # Extract LLM response and references from unified result + llm_response = result.get("llm_response", {}) + data = result.get("data", {}) + references = data.get("references", []) + + # Get the non-streaming response content + response_content = llm_response.get("content", "") + if not response_content: + response_content = "No relevant context found for the query." + + # Enrich references with chunk content if requested + if request.include_references and request.include_chunk_content: + chunks = data.get("chunks", []) + # Create a mapping from reference_id to chunk content + ref_id_to_content = {} + for chunk in chunks: + ref_id = chunk.get("reference_id", "") + content = chunk.get("content", "") + if ref_id and content: + # Collect chunk content; join later to avoid quadratic string concatenation + ref_id_to_content.setdefault(ref_id, []).append(content) + + # Add content to references + enriched_references = [] + for ref in references: + ref_copy = ref.copy() + ref_id = ref.get("reference_id", "") + if ref_id in ref_id_to_content: + # Keep content as a list of chunks (one file may have multiple chunks) + ref_copy["content"] = ref_id_to_content[ref_id] + enriched_references.append(ref_copy) + references = enriched_references + + # Return response with or without references based on request + if request.include_references: + return QueryResponse(response=response_content, references=references) + else: + return QueryResponse(response=response_content, references=None) + except Exception as e: + logger.error(f"Error processing query: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post( + "/query/stream", + dependencies=[Depends(combined_auth)], + responses={ + 200: { + "description": "Flexible RAG query response - format depends on stream parameter", + "content": { + "application/x-ndjson": { + "schema": { + "type": "string", + "format": "ndjson", + "description": "Newline-delimited JSON (NDJSON) format used for both streaming and non-streaming responses. For streaming: multiple lines with separate JSON objects. For non-streaming: single line with complete JSON object.", + "example": '{"references": [{"reference_id": "1", "file_path": "/documents/ai.pdf"}]}\n{"response": "Artificial Intelligence is"}\n{"response": " a field of computer science"}\n{"response": " that focuses on creating intelligent machines."}', + }, + "examples": { + "streaming_with_references": { + "summary": "Streaming mode with references (stream=true)", + "description": "Multiple NDJSON lines when stream=True and include_references=True. First line contains references, subsequent lines contain response chunks.", + "value": '{"references": [{"reference_id": "1", "file_path": "/documents/ai_overview.pdf"}, {"reference_id": "2", "file_path": "/documents/ml_basics.txt"}]}\n{"response": "Artificial Intelligence (AI) is a branch of computer science"}\n{"response": " that aims to create intelligent machines capable of performing"}\n{"response": " tasks that typically require human intelligence, such as learning,"}\n{"response": " reasoning, and problem-solving."}', + }, + "streaming_with_chunk_content": { + "summary": "Streaming mode with chunk content (stream=true, include_chunk_content=true)", + "description": "Multiple NDJSON lines when stream=True, include_references=True, and include_chunk_content=True. First line contains references with content arrays (one file may have multiple chunks), subsequent lines contain response chunks.", + "value": '{"references": [{"reference_id": "1", "file_path": "/documents/ai_overview.pdf", "content": ["Artificial Intelligence (AI) represents a transformative field...", "AI systems can be categorized into narrow AI and general AI..."]}, {"reference_id": "2", "file_path": "/documents/ml_basics.txt", "content": ["Machine learning is a subset of AI that enables computers to learn..."]}]}\n{"response": "Artificial Intelligence (AI) is a branch of computer science"}\n{"response": " that aims to create intelligent machines capable of performing"}\n{"response": " tasks that typically require human intelligence."}', + }, + "streaming_without_references": { + "summary": "Streaming mode without references (stream=true)", + "description": "Multiple NDJSON lines when stream=True and include_references=False. Only response chunks are sent.", + "value": '{"response": "Machine learning is a subset of artificial intelligence"}\n{"response": " that enables computers to learn and improve from experience"}\n{"response": " without being explicitly programmed for every task."}', + }, + "non_streaming_with_references": { + "summary": "Non-streaming mode with references (stream=false)", + "description": "Single NDJSON line when stream=False and include_references=True. Complete response with references in one message.", + "value": '{"references": [{"reference_id": "1", "file_path": "/documents/neural_networks.pdf"}], "response": "Neural networks are computational models inspired by biological neural networks that consist of interconnected nodes (neurons) organized in layers. They are fundamental to deep learning and can learn complex patterns from data through training processes."}', + }, + "non_streaming_without_references": { + "summary": "Non-streaming mode without references (stream=false)", + "description": "Single NDJSON line when stream=False and include_references=False. Complete response only.", + "value": '{"response": "Deep learning is a subset of machine learning that uses neural networks with multiple layers (hence deep) to model and understand complex patterns in data. It has revolutionized fields like computer vision, natural language processing, and speech recognition."}', + }, + "error_response": { + "summary": "Error during streaming", + "description": "Error handling in NDJSON format when an error occurs during processing.", + "value": '{"references": [{"reference_id": "1", "file_path": "/documents/ai.pdf"}]}\n{"response": "Artificial Intelligence is"}\n{"error": "LLM service temporarily unavailable"}', + }, + }, + } + }, + }, + 400: { + "description": "Bad Request - Invalid input parameters", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + "example": { + "detail": "Query text must be at least 3 characters long" + }, + } + }, + }, + 500: { + "description": "Internal Server Error - Query processing failed", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + "example": { + "detail": "Failed to process streaming query: Knowledge graph unavailable" + }, + } + }, + }, + }, + ) + async def query_text_stream(request: QueryRequest): + """ + Advanced RAG query endpoint with flexible streaming response. + + This endpoint provides the most flexible querying experience, supporting both real-time streaming + and complete response delivery based on your integration needs. + + **Response Modes:** + - Real-time response delivery as content is generated + - NDJSON format: each line is a separate JSON object + - First line: `{"references": [...]}` (if include_references=True) + - Subsequent lines: `{"response": "content chunk"}` + - Error handling: `{"error": "error message"}` + + > If stream parameter is False, or the query hit LLM cache, complete response delivered in a single streaming message. + + **Response Format Details** + - **Content-Type**: `application/x-ndjson` (Newline-Delimited JSON) + - **Structure**: Each line is an independent, valid JSON object + - **Parsing**: Process line-by-line, each line is self-contained + - **Headers**: Includes cache control and connection management + + **Query Modes (same as /query endpoint)** + - **local**: Entity-focused retrieval with direct relationships + - **global**: Pattern analysis across the knowledge graph + - **hybrid**: Combined local and global strategies + - **naive**: Vector similarity search only + - **mix**: Integrated knowledge graph + vector retrieval (recommended) + - **bypass**: Direct LLM query without knowledge retrieval + + conversation_history parameteris sent to LLM only, does not affect retrieval results. + + **Usage Examples** + + Real-time streaming query: + ```json + { + "query": "Explain machine learning algorithms", + "mode": "mix", + "stream": true, + "include_references": true + } + ``` + + Bypass initial LLM call by providing high-level and low-level keywords: + ```json + { + "query": "What is Retrieval-Augmented-Generation?", + "hl_keywords": ["machine learning", "information retrieval", "natural language processing"], + "ll_keywords": ["retrieval augmented generation", "RAG", "knowledge base"], + "mode": "mix" + } + ``` + + Complete response query: + ```json + { + "query": "What is deep learning?", + "mode": "hybrid", + "stream": false, + "response_type": "Multiple Paragraphs" + } + ``` + + Conversation with context: + ```json + { + "query": "Can you elaborate on that?", + "stream": true, + "conversation_history": [ + {"role": "user", "content": "What is neural network?"}, + {"role": "assistant", "content": "A neural network is..."} + ] + } + ``` + + **Response Processing:** + + ```python + async for line in response.iter_lines(): + data = json.loads(line) + if "references" in data: + # Handle references (first message) + references = data["references"] + if "response" in data: + # Handle content chunk + content_chunk = data["response"] + if "error" in data: + # Handle error + error_message = data["error"] + ``` + + **Error Handling:** + - Streaming errors are delivered as `{"error": "message"}` lines + - Non-streaming errors raise HTTP exceptions + - Partial responses may be delivered before errors in streaming mode + - Always check for error objects when processing streaming responses + + Args: + request (QueryRequest): The request object containing query parameters: + - **query**: The question or prompt to process (min 3 characters) + - **mode**: Query strategy - "mix" recommended for best results + - **stream**: Enable streaming (True) or complete response (False) + - **include_references**: Whether to include source citations + - **response_type**: Format preference (e.g., "Multiple Paragraphs") + - **top_k**: Number of top entities/relations to retrieve + - **conversation_history**: Previous dialogue context for multi-turn conversations + - **max_total_tokens**: Token budget for the entire response + + Returns: + StreamingResponse: NDJSON streaming response containing: + - **Streaming mode**: Multiple JSON objects, one per line + - References object (if requested): `{"references": [...]}` + - Content chunks: `{"response": "chunk content"}` + - Error objects: `{"error": "error message"}` + - **Non-streaming mode**: Single JSON object + - Complete response: `{"references": [...], "response": "complete content"}` + + Raises: + HTTPException: + - 400: Invalid input parameters (e.g., query too short, invalid mode) + - 500: Internal processing error (e.g., LLM service unavailable) + + Note: + This endpoint is ideal for applications requiring flexible response delivery. + Use streaming mode for real-time interfaces and non-streaming for batch processing. + """ + try: + # Use the stream parameter from the request, defaulting to True if not specified + stream_mode = request.stream if request.stream is not None else True + param = request.to_query_params(stream_mode) + + from fastapi.responses import StreamingResponse + + # Unified approach: always use aquery_llm for all cases + result = await rag.aquery_llm(request.query, param=param) + + async def stream_generator(): + # Extract references and LLM response from unified result + references = result.get("data", {}).get("references", []) + llm_response = result.get("llm_response", {}) + + # Enrich references with chunk content if requested + if request.include_references and request.include_chunk_content: + data = result.get("data", {}) + chunks = data.get("chunks", []) + # Create a mapping from reference_id to chunk content + ref_id_to_content = {} + for chunk in chunks: + ref_id = chunk.get("reference_id", "") + content = chunk.get("content", "") + if ref_id and content: + # Collect chunk content + ref_id_to_content.setdefault(ref_id, []).append(content) + + # Add content to references + enriched_references = [] + for ref in references: + ref_copy = ref.copy() + ref_id = ref.get("reference_id", "") + if ref_id in ref_id_to_content: + # Keep content as a list of chunks (one file may have multiple chunks) + ref_copy["content"] = ref_id_to_content[ref_id] + enriched_references.append(ref_copy) + references = enriched_references + + if llm_response.get("is_streaming"): + # Streaming mode: send references first, then stream response chunks + if request.include_references: + yield f"{json.dumps({'references': references})}\n" + + response_stream = llm_response.get("response_iterator") + if response_stream: + try: + async for chunk in response_stream: + if chunk: # Only send non-empty content + yield f"{json.dumps({'response': chunk})}\n" + except Exception as e: + logger.error(f"Streaming error: {str(e)}") + yield f"{json.dumps({'error': str(e)})}\n" + else: + # Non-streaming mode: send complete response in one message + response_content = llm_response.get("content", "") + if not response_content: + response_content = "No relevant context found for the query." + + # Create complete response object + complete_response = {"response": response_content} + if request.include_references: + complete_response["references"] = references + + yield f"{json.dumps(complete_response)}\n" + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx + }, + ) + except Exception as e: + logger.error(f"Error processing streaming query: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + @router.post( + "/query/data", + response_model=QueryDataResponse, + dependencies=[Depends(combined_auth)], + responses={ + 200: { + "description": "Successful data retrieval response with structured RAG data", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": { + "type": "string", + "enum": ["success", "failure"], + "description": "Query execution status", + }, + "message": { + "type": "string", + "description": "Status message describing the result", + }, + "data": { + "type": "object", + "properties": { + "entities": { + "type": "array", + "items": { + "type": "object", + "properties": { + "entity_name": {"type": "string"}, + "entity_type": {"type": "string"}, + "description": {"type": "string"}, + "source_id": {"type": "string"}, + "file_path": {"type": "string"}, + "reference_id": {"type": "string"}, + }, + }, + "description": "Retrieved entities from knowledge graph", + }, + "relationships": { + "type": "array", + "items": { + "type": "object", + "properties": { + "src_id": {"type": "string"}, + "tgt_id": {"type": "string"}, + "description": {"type": "string"}, + "keywords": {"type": "string"}, + "weight": {"type": "number"}, + "source_id": {"type": "string"}, + "file_path": {"type": "string"}, + "reference_id": {"type": "string"}, + }, + }, + "description": "Retrieved relationships from knowledge graph", + }, + "chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": {"type": "string"}, + "file_path": {"type": "string"}, + "chunk_id": {"type": "string"}, + "reference_id": {"type": "string"}, + }, + }, + "description": "Retrieved text chunks from vector database", + }, + "references": { + "type": "array", + "items": { + "type": "object", + "properties": { + "reference_id": {"type": "string"}, + "file_path": {"type": "string"}, + }, + }, + "description": "Reference list for citation purposes", + }, + }, + "description": "Structured retrieval data containing entities, relationships, chunks, and references", + }, + "metadata": { + "type": "object", + "properties": { + "query_mode": {"type": "string"}, + "keywords": { + "type": "object", + "properties": { + "high_level": { + "type": "array", + "items": {"type": "string"}, + }, + "low_level": { + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + "processing_info": { + "type": "object", + "properties": { + "total_entities_found": { + "type": "integer" + }, + "total_relations_found": { + "type": "integer" + }, + "entities_after_truncation": { + "type": "integer" + }, + "relations_after_truncation": { + "type": "integer" + }, + "final_chunks_count": { + "type": "integer" + }, + }, + }, + }, + "description": "Query metadata including mode, keywords, and processing information", + }, + }, + "required": ["status", "message", "data", "metadata"], + }, + "examples": { + "successful_local_mode": { + "summary": "Local mode data retrieval", + "description": "Example of structured data from local mode query focusing on specific entities", + "value": { + "status": "success", + "message": "Query executed successfully", + "data": { + "entities": [ + { + "entity_name": "Neural Networks", + "entity_type": "CONCEPT", + "description": "Computational models inspired by biological neural networks", + "source_id": "chunk-123", + "file_path": "/documents/ai_basics.pdf", + "reference_id": "1", + } + ], + "relationships": [ + { + "src_id": "Neural Networks", + "tgt_id": "Machine Learning", + "description": "Neural networks are a subset of machine learning algorithms", + "keywords": "subset, algorithm, learning", + "weight": 0.85, + "source_id": "chunk-123", + "file_path": "/documents/ai_basics.pdf", + "reference_id": "1", + } + ], + "chunks": [ + { + "content": "Neural networks are computational models that mimic the way biological neural networks work...", + "file_path": "/documents/ai_basics.pdf", + "chunk_id": "chunk-123", + "reference_id": "1", + } + ], + "references": [ + { + "reference_id": "1", + "file_path": "/documents/ai_basics.pdf", + } + ], + }, + "metadata": { + "query_mode": "local", + "keywords": { + "high_level": ["neural", "networks"], + "low_level": [ + "computation", + "model", + "algorithm", + ], + }, + "processing_info": { + "total_entities_found": 5, + "total_relations_found": 3, + "entities_after_truncation": 1, + "relations_after_truncation": 1, + "final_chunks_count": 1, + }, + }, + }, + }, + "global_mode": { + "summary": "Global mode data retrieval", + "description": "Example of structured data from global mode query analyzing broader patterns", + "value": { + "status": "success", + "message": "Query executed successfully", + "data": { + "entities": [], + "relationships": [ + { + "src_id": "Artificial Intelligence", + "tgt_id": "Machine Learning", + "description": "AI encompasses machine learning as a core component", + "keywords": "encompasses, component, field", + "weight": 0.92, + "source_id": "chunk-456", + "file_path": "/documents/ai_overview.pdf", + "reference_id": "2", + } + ], + "chunks": [], + "references": [ + { + "reference_id": "2", + "file_path": "/documents/ai_overview.pdf", + } + ], + }, + "metadata": { + "query_mode": "global", + "keywords": { + "high_level": [ + "artificial", + "intelligence", + "overview", + ], + "low_level": [], + }, + }, + }, + }, + "naive_mode": { + "summary": "Naive mode data retrieval", + "description": "Example of structured data from naive mode using only vector search", + "value": { + "status": "success", + "message": "Query executed successfully", + "data": { + "entities": [], + "relationships": [], + "chunks": [ + { + "content": "Deep learning is a subset of machine learning that uses neural networks with multiple layers...", + "file_path": "/documents/deep_learning.pdf", + "chunk_id": "chunk-789", + "reference_id": "3", + } + ], + "references": [ + { + "reference_id": "3", + "file_path": "/documents/deep_learning.pdf", + } + ], + }, + "metadata": { + "query_mode": "naive", + "keywords": {"high_level": [], "low_level": []}, + }, + }, + }, + }, + } + }, + }, + 400: { + "description": "Bad Request - Invalid input parameters", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + "example": { + "detail": "Query text must be at least 3 characters long" + }, + } + }, + }, + 500: { + "description": "Internal Server Error - Data retrieval failed", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": {"detail": {"type": "string"}}, + }, + "example": { + "detail": "Failed to retrieve data: Knowledge graph unavailable" + }, + } + }, + }, + }, + ) + async def query_data(request: QueryRequest): + """ + Advanced data retrieval endpoint for structured RAG analysis. + + This endpoint provides raw retrieval results without LLM generation, perfect for: + - **Data Analysis**: Examine what information would be used for RAG + - **System Integration**: Get structured data for custom processing + - **Debugging**: Understand retrieval behavior and quality + - **Research**: Analyze knowledge graph structure and relationships + + **Key Features:** + - No LLM generation - pure data retrieval + - Complete structured output with entities, relationships, and chunks + - Always includes references for citation + - Detailed metadata about processing and keywords + - Compatible with all query modes and parameters + + **Query Mode Behaviors:** + - **local**: Returns entities and their direct relationships + related chunks + - **global**: Returns relationship patterns across the knowledge graph + - **hybrid**: Combines local and global retrieval strategies + - **naive**: Returns only vector-retrieved text chunks (no knowledge graph) + - **mix**: Integrates knowledge graph data with vector-retrieved chunks + - **bypass**: Returns empty data arrays (used for direct LLM queries) + + **Data Structure:** + - **entities**: Knowledge graph entities with descriptions and metadata + - **relationships**: Connections between entities with weights and descriptions + - **chunks**: Text segments from documents with source information + - **references**: Citation information mapping reference IDs to file paths + - **metadata**: Processing information, keywords, and query statistics + + **Usage Examples:** + + Analyze entity relationships: + ```json + { + "query": "machine learning algorithms", + "mode": "local", + "top_k": 10 + } + ``` + + Explore global patterns: + ```json + { + "query": "artificial intelligence trends", + "mode": "global", + "max_relation_tokens": 2000 + } + ``` + + Vector similarity search: + ```json + { + "query": "neural network architectures", + "mode": "naive", + "chunk_top_k": 5 + } + ``` + + Bypass initial LLM call by providing high-level and low-level keywords: + ```json + { + "query": "What is Retrieval-Augmented-Generation?", + "hl_keywords": ["machine learning", "information retrieval", "natural language processing"], + "ll_keywords": ["retrieval augmented generation", "RAG", "knowledge base"], + "mode": "mix" + } + ``` + + **Response Analysis:** + - **Empty arrays**: Normal for certain modes (e.g., naive mode has no entities/relationships) + - **Processing info**: Shows retrieval statistics and token usage + - **Keywords**: High-level and low-level keywords extracted from query + - **Reference mapping**: Links all data back to source documents + + Args: + request (QueryRequest): The request object containing query parameters: + - **query**: The search query to analyze (min 3 characters) + - **mode**: Retrieval strategy affecting data types returned + - **top_k**: Number of top entities/relationships to retrieve + - **chunk_top_k**: Number of text chunks to retrieve + - **max_entity_tokens**: Token limit for entity context + - **max_relation_tokens**: Token limit for relationship context + - **max_total_tokens**: Overall token budget for retrieval + + Returns: + QueryDataResponse: Structured JSON response containing: + - **status**: "success" or "failure" + - **message**: Human-readable status description + - **data**: Complete retrieval results with entities, relationships, chunks, references + - **metadata**: Query processing information and statistics + + Raises: + HTTPException: + - 400: Invalid input parameters (e.g., query too short, invalid mode) + - 500: Internal processing error (e.g., knowledge graph unavailable) + + Note: + This endpoint always includes references regardless of the include_references parameter, + as structured data analysis typically requires source attribution. + """ + try: + param = request.to_query_params(False) # No streaming for data endpoint + response = await rag.aquery_data(request.query, param=param) + + # aquery_data returns the new format with status, message, data, and metadata + if isinstance(response, dict): + return QueryDataResponse(**response) + else: + # Handle unexpected response format + return QueryDataResponse( + status="failure", + message="Invalid response type", + data={}, + ) + except Exception as e: + logger.error(f"Error processing data query: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + return router diff --git a/lightrag/base.py b/lightrag/base.py index b6380e7eb5..b4eeea6309 100644 --- a/lightrag/base.py +++ b/lightrag/base.py @@ -168,6 +168,9 @@ class QueryParam: containing citation information for the retrieved content. """ + only_kg_context: bool = False + """If True, includes only knowledge graph (entities and relations) in context, excluding all chunks.""" + @dataclass class StorageNameSpace(ABC): diff --git a/lightrag/base_original.py b/lightrag/base_original.py new file mode 100644 index 0000000000..bfbeca2133 --- /dev/null +++ b/lightrag/base_original.py @@ -0,0 +1,907 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum +import os +from dotenv import load_dotenv +from dataclasses import dataclass, field +from typing import ( + Any, + Literal, + TypedDict, + TypeVar, + Callable, + Optional, + Dict, + List, + AsyncIterator, +) +from .utils import EmbeddingFunc +from .types import KnowledgeGraph +from .constants import ( + DEFAULT_TOP_K, + DEFAULT_CHUNK_TOP_K, + DEFAULT_MAX_ENTITY_TOKENS, + DEFAULT_MAX_RELATION_TOKENS, + DEFAULT_MAX_TOTAL_TOKENS, + DEFAULT_HISTORY_TURNS, + DEFAULT_OLLAMA_MODEL_NAME, + DEFAULT_OLLAMA_MODEL_TAG, + DEFAULT_OLLAMA_MODEL_SIZE, + DEFAULT_OLLAMA_CREATED_AT, + DEFAULT_OLLAMA_DIGEST, +) + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + + +class OllamaServerInfos: + def __init__(self, name=None, tag=None): + self._lightrag_name = name or os.getenv( + "OLLAMA_EMULATING_MODEL_NAME", DEFAULT_OLLAMA_MODEL_NAME + ) + self._lightrag_tag = tag or os.getenv( + "OLLAMA_EMULATING_MODEL_TAG", DEFAULT_OLLAMA_MODEL_TAG + ) + self.LIGHTRAG_SIZE = DEFAULT_OLLAMA_MODEL_SIZE + self.LIGHTRAG_CREATED_AT = DEFAULT_OLLAMA_CREATED_AT + self.LIGHTRAG_DIGEST = DEFAULT_OLLAMA_DIGEST + + @property + def LIGHTRAG_NAME(self): + return self._lightrag_name + + @LIGHTRAG_NAME.setter + def LIGHTRAG_NAME(self, value): + self._lightrag_name = value + + @property + def LIGHTRAG_TAG(self): + return self._lightrag_tag + + @LIGHTRAG_TAG.setter + def LIGHTRAG_TAG(self, value): + self._lightrag_tag = value + + @property + def LIGHTRAG_MODEL(self): + return f"{self._lightrag_name}:{self._lightrag_tag}" + + +class TextChunkSchema(TypedDict): + tokens: int + content: str + full_doc_id: str + chunk_order_index: int + + +T = TypeVar("T") + + +@dataclass +class QueryParam: + """Configuration parameters for query execution in LightRAG.""" + + mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = "mix" + """Specifies the retrieval mode: + - "local": Focuses on context-dependent information. + - "global": Utilizes global knowledge. + - "hybrid": Combines local and global retrieval methods. + - "naive": Performs a basic search without advanced techniques. + - "mix": Integrates knowledge graph and vector retrieval. + """ + + only_need_context: bool = False + """If True, only returns the retrieved context without generating a response.""" + + only_need_prompt: bool = False + """If True, only returns the generated prompt without producing a response.""" + + response_type: str = "Multiple Paragraphs" + """Defines the response format. Examples: 'Multiple Paragraphs', 'Single Paragraph', 'Bullet Points'.""" + + stream: bool = False + """If True, enables streaming output for real-time responses.""" + + top_k: int = int(os.getenv("TOP_K", str(DEFAULT_TOP_K))) + """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.""" + + chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", str(DEFAULT_CHUNK_TOP_K))) + """Number of text chunks to retrieve initially from vector search and keep after reranking. + If None, defaults to top_k value. + """ + + max_entity_tokens: int = int( + os.getenv("MAX_ENTITY_TOKENS", str(DEFAULT_MAX_ENTITY_TOKENS)) + ) + """Maximum number of tokens allocated for entity context in unified token control system.""" + + max_relation_tokens: int = int( + os.getenv("MAX_RELATION_TOKENS", str(DEFAULT_MAX_RELATION_TOKENS)) + ) + """Maximum number of tokens allocated for relationship context in unified token control system.""" + + max_total_tokens: int = int( + os.getenv("MAX_TOTAL_TOKENS", str(DEFAULT_MAX_TOTAL_TOKENS)) + ) + """Maximum total tokens budget for the entire query context (entities + relations + chunks + system prompt).""" + + hl_keywords: list[str] = field(default_factory=list) + """List of high-level keywords to prioritize in retrieval.""" + + ll_keywords: list[str] = field(default_factory=list) + """List of low-level keywords to refine retrieval focus.""" + + # History mesages is only send to LLM for context, not used for retrieval + conversation_history: list[dict[str, str]] = field(default_factory=list) + """Stores past conversation history to maintain context. + Format: [{"role": "user/assistant", "content": "message"}]. + """ + + # TODO: deprecated. No longer used in the codebase, all conversation_history messages is send to LLM + history_turns: int = int(os.getenv("HISTORY_TURNS", str(DEFAULT_HISTORY_TURNS))) + """Number of complete conversation turns (user-assistant pairs) to consider in the response context.""" + + model_func: Callable[..., object] | None = None + """Optional override for the LLM model function to use for this specific query. + If provided, this will be used instead of the global model function. + This allows using different models for different query modes. + """ + + user_prompt: str | None = None + """User-provided prompt for the query. + Addition instructions for LLM. If provided, this will be inject into the prompt template. + It's purpose is the let user customize the way LLM generate the response. + """ + + enable_rerank: bool = os.getenv("RERANK_BY_DEFAULT", "true").lower() == "true" + """Enable reranking for retrieved text chunks. If True but no rerank model is configured, a warning will be issued. + Default is True to enable reranking when rerank model is available. + """ + + include_references: bool = False + """If True, includes reference list in the response for supported endpoints. + This parameter controls whether the API response includes a references field + containing citation information for the retrieved content. + """ + + +@dataclass +class StorageNameSpace(ABC): + namespace: str + workspace: str + global_config: dict[str, Any] + + async def initialize(self): + """Initialize the storage""" + pass + + async def finalize(self): + """Finalize the storage""" + pass + + @abstractmethod + async def index_done_callback(self) -> None: + """Commit the storage operations after indexing""" + + @abstractmethod + async def drop(self) -> dict[str, str]: + """Drop all data from storage and clean up resources + + This abstract method defines the contract for dropping all data from a storage implementation. + Each storage type must implement this method to: + 1. Clear all data from memory and/or external storage + 2. Remove any associated storage files if applicable + 3. Reset the storage to its initial state + 4. Handle cleanup of any resources + 5. Notify other processes if necessary + 6. This action should persistent the data to disk immediately. + + Returns: + dict[str, str]: Operation status and message with the following format: + { + "status": str, # "success" or "error" + "message": str # "data dropped" on success, error details on failure + } + + Implementation specific: + - On success: return {"status": "success", "message": "data dropped"} + - On failure: return {"status": "error", "message": ""} + - If not supported: return {"status": "error", "message": "unsupported"} + """ + + +@dataclass +class BaseVectorStorage(StorageNameSpace, ABC): + embedding_func: EmbeddingFunc + cosine_better_than_threshold: float = field(default=0.2) + meta_fields: set[str] = field(default_factory=set) + + def _validate_embedding_func(self): + """Validate that embedding_func is provided. + + This method should be called at the beginning of __post_init__ + in all vector storage implementations. + + Raises: + ValueError: If embedding_func is None + """ + if self.embedding_func is None: + raise ValueError( + "embedding_func is required for vector storage. " + "Please provide a valid EmbeddingFunc instance." + ) + + def _generate_collection_suffix(self) -> str | None: + """Generates collection/table suffix from embedding_func. + + Return suffix if model_name exists in embedding_func, otherwise return None. + Note: embedding_func is guaranteed to exist (validated in __post_init__). + + Returns: + str | None: Suffix string e.g. "text_embedding_3_large_3072d", or None if model_name not available + """ + import re + + # Check if model_name exists (model_name is optional in EmbeddingFunc) + model_name = getattr(self.embedding_func, "model_name", None) + if not model_name: + return None + + # embedding_dim is required in EmbeddingFunc + embedding_dim = self.embedding_func.embedding_dim + + # Generate suffix: clean model name and append dimension + safe_model_name = re.sub(r"[^a-zA-Z0-9_]", "_", model_name.lower()) + return f"{safe_model_name}_{embedding_dim}d" + + @abstractmethod + async def query( + self, query: str, top_k: int, query_embedding: list[float] = None + ) -> list[dict[str, Any]]: + """Query the vector storage and retrieve top_k results. + + Args: + query: The query string to search for + top_k: Number of top results to return + query_embedding: Optional pre-computed embedding for the query. + If provided, skips embedding computation for better performance. + """ + + @abstractmethod + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Insert or update vectors in the storage. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ + + @abstractmethod + async def delete_entity(self, entity_name: str) -> None: + """Delete a single entity by its name. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ + + @abstractmethod + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete relations for a given entity. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + """ + + @abstractmethod + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + pass + + @abstractmethod + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + pass + + @abstractmethod + async def delete(self, ids: list[str]): + """Delete vectors with specified IDs + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + ids: List of vector IDs to be deleted + """ + + @abstractmethod + async def get_vectors_by_ids(self, ids: list[str]) -> dict[str, list[float]]: + """Get vectors by their IDs, returning only ID and vector data for efficiency + + Args: + ids: List of unique identifiers + + Returns: + Dictionary mapping IDs to their vector embeddings + Format: {id: [vector_values], ...} + """ + pass + + +@dataclass +class BaseKVStorage(StorageNameSpace, ABC): + embedding_func: EmbeddingFunc + + @abstractmethod + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get value by id""" + + @abstractmethod + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get values by ids""" + + @abstractmethod + async def filter_keys(self, keys: set[str]) -> set[str]: + """Return un-exist keys""" + + @abstractmethod + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Upsert data + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + """ + + @abstractmethod + async def delete(self, ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. update flags to notify other processes that data persistence is needed + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ + + @abstractmethod + async def is_empty(self) -> bool: + """Check if the storage is empty + + Returns: + bool: True if storage contains no data, False otherwise + """ + + +@dataclass +class BaseGraphStorage(StorageNameSpace, ABC): + """All operations related to edges in graph should be undirected.""" + + embedding_func: EmbeddingFunc + + @abstractmethod + async def has_node(self, node_id: str) -> bool: + """Check if a node exists in the graph. + + Args: + node_id: The ID of the node to check + + Returns: + True if the node exists, False otherwise + """ + + @abstractmethod + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + """Check if an edge exists between two nodes. + + Args: + source_node_id: The ID of the source node + target_node_id: The ID of the target node + + Returns: + True if the edge exists, False otherwise + """ + + @abstractmethod + async def node_degree(self, node_id: str) -> int: + """Get the degree (number of connected edges) of a node. + + Args: + node_id: The ID of the node + + Returns: + The number of edges connected to the node + """ + + @abstractmethod + async def edge_degree(self, src_id: str, tgt_id: str) -> int: + """Get the total degree of an edge (sum of degrees of its source and target nodes). + + Args: + src_id: The ID of the source node + tgt_id: The ID of the target node + + Returns: + The sum of the degrees of the source and target nodes + """ + + @abstractmethod + async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its ID, returning only node properties. + + Args: + node_id: The ID of the node to retrieve + + Returns: + A dictionary of node properties if found, None otherwise + """ + + @abstractmethod + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + """Get edge properties between two nodes. + + Args: + source_node_id: The ID of the source node + target_node_id: The ID of the target node + + Returns: + A dictionary of edge properties if found, None otherwise + """ + + @abstractmethod + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """Get all edges connected to a node. + + Args: + source_node_id: The ID of the node to get edges for + + Returns: + A list of (source_id, target_id) tuples representing edges, + or None if the node doesn't exist + """ + + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + """Get nodes as a batch using UNWIND + + Default implementation fetches nodes one by one. + Override this method for better performance in storage backends + that support batch operations. + """ + result = {} + for node_id in node_ids: + node = await self.get_node(node_id) + if node is not None: + result[node_id] = node + return result + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + """Node degrees as a batch using UNWIND + + Default implementation fetches node degrees one by one. + Override this method for better performance in storage backends + that support batch operations. + """ + result = {} + for node_id in node_ids: + degree = await self.node_degree(node_id) + result[node_id] = degree + return result + + async def edge_degrees_batch( + self, edge_pairs: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: + """Edge degrees as a batch using UNWIND also uses node_degrees_batch + + Default implementation calculates edge degrees one by one. + Override this method for better performance in storage backends + that support batch operations. + """ + result = {} + for src_id, tgt_id in edge_pairs: + degree = await self.edge_degree(src_id, tgt_id) + result[(src_id, tgt_id)] = degree + return result + + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: + """Get edges as a batch using UNWIND + + Default implementation fetches edges one by one. + Override this method for better performance in storage backends + that support batch operations. + """ + result = {} + for pair in pairs: + src_id = pair["src"] + tgt_id = pair["tgt"] + edge = await self.get_edge(src_id, tgt_id) + if edge is not None: + result[(src_id, tgt_id)] = edge + return result + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: + """Get nodes edges as a batch using UNWIND + + Default implementation fetches node edges one by one. + Override this method for better performance in storage backends + that support batch operations. + """ + result = {} + for node_id in node_ids: + edges = await self.get_node_edges(node_id) + result[node_id] = edges if edges is not None else [] + return result + + @abstractmethod + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """Insert a new node or update an existing node in the graph. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + node_id: The ID of the node to insert or update + node_data: A dictionary of node properties + """ + + @abstractmethod + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: + """Insert a new edge or update an existing edge in the graph. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + source_node_id: The ID of the source node + target_node_id: The ID of the target node + edge_data: A dictionary of edge properties + """ + + @abstractmethod + async def delete_node(self, node_id: str) -> None: + """Delete a node from the graph. + + Importance notes for in-memory storage: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + node_id: The ID of the node to delete + """ + + @abstractmethod + async def remove_nodes(self, nodes: list[str]): + """Delete multiple nodes + + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + nodes: List of node IDs to be deleted + """ + + @abstractmethod + async def remove_edges(self, edges: list[tuple[str, str]]): + """Delete multiple edges + + Importance notes: + 1. Changes will be persisted to disk during the next index_done_callback + 2. Only one process should updating the storage at a time before index_done_callback, + KG-storage-log should be used to avoid data corruption + + Args: + edges: List of edges to be deleted, each edge is a (source, target) tuple + """ + + @abstractmethod + async def get_all_labels(self) -> list[str]: + """Get all labels(entity names) in the graph. + Do not use this method for large graph, use get_popular_labels or search_labels instead. + + Returns: + A list of all node labels in the graph, sorted alphabetically + """ + + @abstractmethod + async def get_knowledge_graph( + self, node_label: str, max_depth: int = 3, max_nodes: int = 1000 + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. + + Args: + node_label: Label(entity name) of the starting node,* means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return, Defaults to 1000(BFS if possible) + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + """ + + @abstractmethod + async def get_all_nodes(self) -> list[dict]: + """Get all nodes in the graph. + + Returns: + A list of all nodes, where each node is a dictionary of its properties + (Edge is bidirectional for some storage implementation; deduplication must be handled by the caller) + """ + + @abstractmethod + async def get_all_edges(self) -> list[dict]: + """Get all edges in the graph. + + Returns: + A list of all edges, where each edge is a dictionary of its properties + """ + + @abstractmethod + async def get_popular_labels(self, limit: int = 300) -> list[str]: + """Get popular labels(entity names) by node degree (most connected entities) + + Args: + limit: Maximum number of labels to return + + Returns: + List of labels sorted by degree (highest first) + """ + + @abstractmethod + async def search_labels(self, query: str, limit: int = 50) -> list[str]: + """Search labels(entity names) with fuzzy matching + + Args: + query: Search query string + limit: Maximum number of results to return + + Returns: + List of matching labels sorted by relevance + """ + + +class DocStatus(str, Enum): + """Document processing status""" + + PENDING = "pending" + PROCESSING = "processing" + PREPROCESSED = "preprocessed" + PROCESSED = "processed" + FAILED = "failed" + + +@dataclass +class DocProcessingStatus: + """Document processing status data structure""" + + content_summary: str + """First 100 chars of document content, used for preview""" + content_length: int + """Total length of document""" + file_path: str + """File path of the document""" + status: DocStatus + """Current processing status""" + created_at: str + """ISO format timestamp when document was created""" + updated_at: str + """ISO format timestamp when document was last updated""" + track_id: str | None = None + """Tracking ID for monitoring progress""" + chunks_count: int | None = None + """Number of chunks after splitting, used for processing""" + chunks_list: list[str] | None = field(default_factory=list) + """List of chunk IDs associated with this document, used for deletion""" + error_msg: str | None = None + """Error message if failed""" + metadata: dict[str, Any] = field(default_factory=dict) + """Additional metadata""" + multimodal_processed: bool | None = field(default=None, repr=False) + """Internal field: indicates if multimodal processing is complete. Not shown in repr() but accessible for debugging.""" + + def __post_init__(self): + """ + Handle status conversion based on multimodal_processed field. + + Business rules: + - If multimodal_processed is False and status is PROCESSED, + then change status to PREPROCESSED + - The multimodal_processed field is kept (with repr=False) for internal use and debugging + """ + # Apply status conversion logic + if self.multimodal_processed is not None: + if ( + self.multimodal_processed is False + and self.status == DocStatus.PROCESSED + ): + self.status = DocStatus.PREPROCESSED + + +@dataclass +class DocStatusStorage(BaseKVStorage, ABC): + """Base class for document status storage""" + + @abstractmethod + async def get_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status""" + + @abstractmethod + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific status""" + + @abstractmethod + async def get_docs_by_track_id( + self, track_id: str + ) -> dict[str, DocProcessingStatus]: + """Get all documents with a specific track_id""" + + @abstractmethod + async def get_docs_paginated( + self, + status_filter: DocStatus | None = None, + page: int = 1, + page_size: int = 50, + sort_field: str = "updated_at", + sort_direction: str = "desc", + ) -> tuple[list[tuple[str, DocProcessingStatus]], int]: + """Get documents with pagination support + + Args: + status_filter: Filter by document status, None for all statuses + page: Page number (1-based) + page_size: Number of documents per page (10-200) + sort_field: Field to sort by ('created_at', 'updated_at', 'id') + sort_direction: Sort direction ('asc' or 'desc') + + Returns: + Tuple of (list of (doc_id, DocProcessingStatus) tuples, total_count) + """ + + @abstractmethod + async def get_all_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status for all documents + + Returns: + Dictionary mapping status names to counts + """ + + @abstractmethod + async def get_doc_by_file_path(self, file_path: str) -> dict[str, Any] | None: + """Get document by file path + + Args: + file_path: The file path to search for + + Returns: + dict[str, Any] | None: Document data if found, None otherwise + Returns the same format as get_by_ids method + """ + + +class StoragesStatus(str, Enum): + """Storages status""" + + NOT_CREATED = "not_created" + CREATED = "created" + INITIALIZED = "initialized" + FINALIZED = "finalized" + + +@dataclass +class DeletionResult: + """Represents the result of a deletion operation.""" + + status: Literal["success", "not_found", "fail"] + doc_id: str + message: str + status_code: int = 200 + file_path: str | None = None + + +# Unified Query Result Data Structures for Reference List Support + + +@dataclass +class QueryResult: + """ + Unified query result data structure for all query modes. + + Attributes: + content: Text content for non-streaming responses + response_iterator: Streaming response iterator for streaming responses + raw_data: Complete structured data including references and metadata + is_streaming: Whether this is a streaming result + """ + + content: Optional[str] = None + response_iterator: Optional[AsyncIterator[str]] = None + raw_data: Optional[Dict[str, Any]] = None + is_streaming: bool = False + + @property + def reference_list(self) -> List[Dict[str, str]]: + """ + Convenient property to extract reference list from raw_data. + + Returns: + List[Dict[str, str]]: Reference list in format: + [{"reference_id": "1", "file_path": "/path/to/file.pdf"}, ...] + """ + if self.raw_data: + return self.raw_data.get("data", {}).get("references", []) + return [] + + @property + def metadata(self) -> Dict[str, Any]: + """ + Convenient property to extract metadata from raw_data. + + Returns: + Dict[str, Any]: Query metadata including query_mode, keywords, etc. + """ + if self.raw_data: + return self.raw_data.get("metadata", {}) + return {} + + +@dataclass +class QueryContextResult: + """ + Unified query context result data structure. + + Attributes: + context: LLM context string + raw_data: Complete structured data including reference_list + """ + + context: str + raw_data: Dict[str, Any] + + @property + def reference_list(self) -> List[Dict[str, str]]: + """Convenient property to extract reference list from raw_data.""" + return self.raw_data.get("data", {}).get("references", []) diff --git a/lightrag/kg/shared_storage.py b/lightrag/kg/shared_storage.py index 6da563088e..73194251a6 100644 --- a/lightrag/kg/shared_storage.py +++ b/lightrag/kg/shared_storage.py @@ -1,3 +1,8 @@ +# This script has been modified by JRS to create an LightRAG API end point for updating the knowledge graph. +# This refresh is required in order to see changes after merges, updates, adds, and deletes to the index. +# This was the function added: async def set_all_update_flags_for_all_namespaces(): +# The function works with changes in lightrag/api/routers/graph_routes.py + import os import sys import asyncio @@ -1357,6 +1362,35 @@ async def set_all_update_flags(namespace: str, workspace: str | None = None): _update_flags[final_namespace][i].value = True +# Start of: Added by JRS +async def set_all_update_flags_for_all_namespaces(): + """ + Set update flags for all existing namespaces, indicating that all workers + need to reload data from files for these namespaces. + """ + global _update_flags + if _update_flags is None: + raise ValueError("Shared-Data is not initialized.") + + async with get_internal_lock(): + if not _update_flags: + direct_log("No namespaces found to set update flags for.") + return + + for namespace in list( + _update_flags.keys() + ): # Iterate over a copy of keys as _update_flags might change + direct_log(f"Setting update flags for namespace: {namespace}") + # Ensure the namespace exists before trying to iterate its flags + if namespace in _update_flags: + for i in range(len(_update_flags[namespace])): + _update_flags[namespace][i].value = True + direct_log("All namespaces' update flags set to True.") + + +# End of: Added by JRS + + async def clear_all_update_flags(namespace: str, workspace: str | None = None): """Clear all update flag of namespace indicating all workers need to reload data from files""" global _update_flags diff --git a/lightrag/kg/shared_storage_original.py b/lightrag/kg/shared_storage_original.py new file mode 100644 index 0000000000..6da563088e --- /dev/null +++ b/lightrag/kg/shared_storage_original.py @@ -0,0 +1,1717 @@ +import os +import sys +import asyncio +import multiprocessing as mp +from multiprocessing.synchronize import Lock as ProcessLock +from multiprocessing import Manager +import time +import logging +from contextvars import ContextVar +from typing import Any, Dict, List, Optional, Union, TypeVar, Generic + +from lightrag.exceptions import PipelineNotInitializedError + +DEBUG_LOCKS = False + + +# Define a direct print function for critical logs that must be visible in all processes +def direct_log(message, enable_output: bool = True, level: str = "DEBUG"): + """ + Log a message directly to stderr to ensure visibility in all processes, + including the Gunicorn master process. + + Args: + message: The message to log + level: Log level for message (control the visibility of the message by comparing with the current logger level) + enable_output: Enable or disable log message (Force to turn off the message,) + """ + if not enable_output: + return + + # Get the current logger level from the lightrag logger + try: + from lightrag.utils import logger + + current_level = logger.getEffectiveLevel() + except ImportError: + # Fallback if lightrag.utils is not available + current_level = 20 # INFO + + # Convert string level to numeric level for comparison + level_mapping = { + "DEBUG": 10, # DEBUG + "INFO": 20, # INFO + "WARNING": 30, # WARNING + "ERROR": 40, # ERROR + "CRITICAL": 50, # CRITICAL + } + message_level = level_mapping.get(level.upper(), logging.DEBUG) + + if message_level >= current_level: + print(f"{level}: {message}", file=sys.stderr, flush=True) + + +T = TypeVar("T") +LockType = Union[ProcessLock, asyncio.Lock] + +_is_multiprocess = None +_workers = None +_manager = None + +# Global singleton data for multi-process keyed locks +_lock_registry: Optional[Dict[str, mp.synchronize.Lock]] = None +_lock_registry_count: Optional[Dict[str, int]] = None +_lock_cleanup_data: Optional[Dict[str, time.time]] = None +_registry_guard = None +# Timeout for keyed locks in seconds (Default 300) +CLEANUP_KEYED_LOCKS_AFTER_SECONDS = 300 +# Cleanup pending list threshold for triggering cleanup (Default 500) +CLEANUP_THRESHOLD = 500 +# Minimum interval between cleanup operations in seconds (Default 30) +MIN_CLEANUP_INTERVAL_SECONDS = 30 +# Track the earliest cleanup time for efficient cleanup triggering (multiprocess locks only) +_earliest_mp_cleanup_time: Optional[float] = None +# Track the last cleanup time to enforce minimum interval (multiprocess locks only) +_last_mp_cleanup_time: Optional[float] = None + +_initialized = None + +# Default workspace for backward compatibility +_default_workspace: Optional[str] = None + +# shared data for storage across processes +_shared_dicts: Optional[Dict[str, Any]] = None +_init_flags: Optional[Dict[str, bool]] = None # namespace -> initialized +_update_flags: Optional[Dict[str, bool]] = None # namespace -> updated + +# locks for mutex access +_internal_lock: Optional[LockType] = None +_data_init_lock: Optional[LockType] = None +# Manager for all keyed locks +_storage_keyed_lock: Optional["KeyedUnifiedLock"] = None + +# async locks for coroutine synchronization in multiprocess mode +_async_locks: Optional[Dict[str, asyncio.Lock]] = None + +_debug_n_locks_acquired: int = 0 + + +def get_final_namespace(namespace: str, workspace: str | None = None): + global _default_workspace + if workspace is None: + workspace = _default_workspace + + if workspace is None: + direct_log( + f"Error: Invoke namespace operation without workspace, pid={os.getpid()}", + level="ERROR", + ) + raise ValueError("Invoke namespace operation without workspace") + + final_namespace = f"{workspace}:{namespace}" if workspace else f"{namespace}" + return final_namespace + + +def inc_debug_n_locks_acquired(): + global _debug_n_locks_acquired + if DEBUG_LOCKS: + _debug_n_locks_acquired += 1 + print(f"DEBUG: Keyed Lock acquired, total: {_debug_n_locks_acquired:>5}") + + +def dec_debug_n_locks_acquired(): + global _debug_n_locks_acquired + if DEBUG_LOCKS: + if _debug_n_locks_acquired > 0: + _debug_n_locks_acquired -= 1 + print(f"DEBUG: Keyed Lock released, total: {_debug_n_locks_acquired:>5}") + else: + raise RuntimeError("Attempting to release lock when no locks are acquired") + + +def get_debug_n_locks_acquired(): + global _debug_n_locks_acquired + return _debug_n_locks_acquired + + +class UnifiedLock(Generic[T]): + """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock""" + + def __init__( + self, + lock: Union[ProcessLock, asyncio.Lock], + is_async: bool, + name: str = "unnamed", + enable_logging: bool = True, + async_lock: Optional[asyncio.Lock] = None, + ): + self._lock = lock + self._is_async = is_async + self._pid = os.getpid() # for debug only + self._name = name # for debug only + self._enable_logging = enable_logging # for debug only + self._async_lock = async_lock # auxiliary lock for coroutine synchronization + + async def __aenter__(self) -> "UnifiedLock[T]": + try: + # If in multiprocess mode and async lock exists, acquire it first + if not self._is_async and self._async_lock is not None: + await self._async_lock.acquire() + direct_log( + f"== Lock == Process {self._pid}: Acquired async lock '{self._name}", + level="DEBUG", + enable_output=self._enable_logging, + ) + + # Acquire the main lock + # Note: self._lock should never be None here as the check has been moved + # to get_internal_lock() and get_data_init_lock() functions + if self._is_async: + await self._lock.acquire() + else: + self._lock.acquire() + + direct_log( + f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})", + level="INFO", + enable_output=self._enable_logging, + ) + return self + except Exception as e: + # If main lock acquisition fails, release the async lock if it was acquired + if ( + not self._is_async + and self._async_lock is not None + and self._async_lock.locked() + ): + self._async_lock.release() + + direct_log( + f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}", + level="ERROR", + enable_output=True, + ) + raise + + async def __aexit__(self, exc_type, exc_val, exc_tb): + main_lock_released = False + async_lock_released = False + try: + # Release main lock first + if self._lock is not None: + if self._is_async: + self._lock.release() + else: + self._lock.release() + + direct_log( + f"== Lock == Process {self._pid}: Released lock {self._name} (async={self._is_async})", + level="INFO", + enable_output=self._enable_logging, + ) + main_lock_released = True + + # Then release async lock if in multiprocess mode + if not self._is_async and self._async_lock is not None: + self._async_lock.release() + direct_log( + f"== Lock == Process {self._pid}: Released async lock {self._name}", + level="DEBUG", + enable_output=self._enable_logging, + ) + async_lock_released = True + + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}", + level="ERROR", + enable_output=True, + ) + + # If main lock release failed but async lock hasn't been attempted yet, try to release it + if ( + not main_lock_released + and not async_lock_released + and not self._is_async + and self._async_lock is not None + ): + try: + direct_log( + f"== Lock == Process {self._pid}: Attempting to release async lock after main lock failure", + level="DEBUG", + enable_output=self._enable_logging, + ) + self._async_lock.release() + direct_log( + f"== Lock == Process {self._pid}: Successfully released async lock after main lock failure", + level="INFO", + enable_output=self._enable_logging, + ) + except Exception as inner_e: + direct_log( + f"== Lock == Process {self._pid}: Failed to release async lock after main lock failure: {inner_e}", + level="ERROR", + enable_output=True, + ) + + raise + + def __enter__(self) -> "UnifiedLock[T]": + """For backward compatibility""" + try: + if self._is_async: + raise RuntimeError("Use 'async with' for shared_storage lock") + + # Acquire the main lock + # Note: self._lock should never be None here as the check has been moved + # to get_internal_lock() and get_data_init_lock() functions + direct_log( + f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)", + level="DEBUG", + enable_output=self._enable_logging, + ) + self._lock.acquire() + direct_log( + f"== Lock == Process {self._pid}: Acquired lock {self._name} (sync)", + level="INFO", + enable_output=self._enable_logging, + ) + return self + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}", + level="ERROR", + enable_output=True, + ) + raise + + def __exit__(self, exc_type, exc_val, exc_tb): + """For backward compatibility""" + try: + if self._is_async: + raise RuntimeError("Use 'async with' for shared_storage lock") + direct_log( + f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)", + level="DEBUG", + enable_output=self._enable_logging, + ) + self._lock.release() + direct_log( + f"== Lock == Process {self._pid}: Released lock {self._name} (sync)", + level="INFO", + enable_output=self._enable_logging, + ) + except Exception as e: + direct_log( + f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}", + level="ERROR", + enable_output=True, + ) + raise + + def locked(self) -> bool: + if self._is_async: + return self._lock.locked() + else: + return self._lock.locked() + + +def _get_combined_key(factory_name: str, key: str) -> str: + """Return the combined key for the factory and key.""" + return f"{factory_name}:{key}" + + +def _perform_lock_cleanup( + lock_type: str, + cleanup_data: Dict[str, float], + lock_registry: Optional[Dict[str, Any]], + lock_count: Optional[Dict[str, int]], + earliest_cleanup_time: Optional[float], + last_cleanup_time: Optional[float], + current_time: float, + threshold_check: bool = True, +) -> tuple[int, Optional[float], Optional[float]]: + """ + Generic lock cleanup function to unify cleanup logic for both multiprocess and async locks. + + Args: + lock_type: Lock type identifier ("mp" or "async") + cleanup_data: Cleanup data dictionary + lock_registry: Lock registry dictionary (can be None for async locks) + lock_count: Lock count dictionary (can be None for async locks) + earliest_cleanup_time: Earliest cleanup time + last_cleanup_time: Last cleanup time + current_time: Current time + threshold_check: Whether to check threshold condition (default True, set to False in cleanup_expired_locks) + + Returns: + tuple: (cleaned_count, new_earliest_time, new_last_cleanup_time) + """ + if len(cleanup_data) == 0: + return 0, earliest_cleanup_time, last_cleanup_time + + # If threshold check is needed and threshold not reached, return directly + if threshold_check and len(cleanup_data) < CLEANUP_THRESHOLD: + return 0, earliest_cleanup_time, last_cleanup_time + + # Time rollback detection + if last_cleanup_time is not None and current_time < last_cleanup_time: + direct_log( + f"== {lock_type} Lock == Time rollback detected, resetting cleanup time", + level="WARNING", + enable_output=False, + ) + last_cleanup_time = None + + # Check cleanup conditions + has_expired_locks = ( + earliest_cleanup_time is not None + and current_time - earliest_cleanup_time > CLEANUP_KEYED_LOCKS_AFTER_SECONDS + ) + + interval_satisfied = ( + last_cleanup_time is None + or current_time - last_cleanup_time > MIN_CLEANUP_INTERVAL_SECONDS + ) + + if not (has_expired_locks and interval_satisfied): + return 0, earliest_cleanup_time, last_cleanup_time + + try: + cleaned_count = 0 + new_earliest_time = None + + # Calculate total count before cleanup + total_cleanup_len = len(cleanup_data) + + # Perform cleanup operation + for cleanup_key, cleanup_time in list(cleanup_data.items()): + if current_time - cleanup_time > CLEANUP_KEYED_LOCKS_AFTER_SECONDS: + # Remove from cleanup data + cleanup_data.pop(cleanup_key, None) + + # Remove from lock registry if exists + if lock_registry is not None: + lock_registry.pop(cleanup_key, None) + if lock_count is not None: + lock_count.pop(cleanup_key, None) + + cleaned_count += 1 + else: + # Track the earliest time among remaining locks + if new_earliest_time is None or cleanup_time < new_earliest_time: + new_earliest_time = cleanup_time + + # Update state only after successful cleanup + if cleaned_count > 0: + new_last_cleanup_time = current_time + + # Log cleanup results + next_cleanup_in = max( + (new_earliest_time + CLEANUP_KEYED_LOCKS_AFTER_SECONDS - current_time) + if new_earliest_time + else float("inf"), + MIN_CLEANUP_INTERVAL_SECONDS, + ) + + if lock_type == "async": + direct_log( + f"== {lock_type} Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired {lock_type} locks, " + f"next cleanup in {next_cleanup_in:.1f}s", + enable_output=False, + level="INFO", + ) + else: + direct_log( + f"== {lock_type} Lock == Cleaned up {cleaned_count}/{total_cleanup_len} expired locks, " + f"next cleanup in {next_cleanup_in:.1f}s", + enable_output=False, + level="INFO", + ) + + return cleaned_count, new_earliest_time, new_last_cleanup_time + else: + return 0, earliest_cleanup_time, last_cleanup_time + + except Exception as e: + direct_log( + f"== {lock_type} Lock == Cleanup failed: {e}", + level="ERROR", + enable_output=True, + ) + return 0, earliest_cleanup_time, last_cleanup_time + + +def _get_or_create_shared_raw_mp_lock( + factory_name: str, key: str +) -> Optional[mp.synchronize.Lock]: + """Return the *singleton* manager.Lock() proxy for keyed lock, creating if needed.""" + if not _is_multiprocess: + return None + + with _registry_guard: + combined_key = _get_combined_key(factory_name, key) + raw = _lock_registry.get(combined_key) + count = _lock_registry_count.get(combined_key) + if raw is None: + raw = _manager.Lock() + _lock_registry[combined_key] = raw + count = 0 + else: + if count is None: + raise RuntimeError( + f"Shared-Data lock registry for {factory_name} is corrupted for key {key}" + ) + if ( + count == 0 and combined_key in _lock_cleanup_data + ): # Reusing an key waiting for cleanup, remove it from cleanup list + _lock_cleanup_data.pop(combined_key) + count += 1 + _lock_registry_count[combined_key] = count + return raw + + +def _release_shared_raw_mp_lock(factory_name: str, key: str): + """Release the *singleton* manager.Lock() proxy for *key*.""" + if not _is_multiprocess: + return + + global _earliest_mp_cleanup_time, _last_mp_cleanup_time + + with _registry_guard: + combined_key = _get_combined_key(factory_name, key) + raw = _lock_registry.get(combined_key) + count = _lock_registry_count.get(combined_key) + if raw is None and count is None: + return + elif raw is None or count is None: + raise RuntimeError( + f"Shared-Data lock registry for {factory_name} is corrupted for key {key}" + ) + + count -= 1 + if count < 0: + raise RuntimeError( + f"Attempting to release lock for {key} more times than it was acquired" + ) + + _lock_registry_count[combined_key] = count + + current_time = time.time() + if count == 0: + _lock_cleanup_data[combined_key] = current_time + + # Update earliest multiprocess cleanup time (only when earlier) + if ( + _earliest_mp_cleanup_time is None + or current_time < _earliest_mp_cleanup_time + ): + _earliest_mp_cleanup_time = current_time + + # Use generic cleanup function + cleaned_count, new_earliest_time, new_last_cleanup_time = _perform_lock_cleanup( + lock_type="mp", + cleanup_data=_lock_cleanup_data, + lock_registry=_lock_registry, + lock_count=_lock_registry_count, + earliest_cleanup_time=_earliest_mp_cleanup_time, + last_cleanup_time=_last_mp_cleanup_time, + current_time=current_time, + threshold_check=True, + ) + + # Update global state if cleanup was performed + if cleaned_count > 0: + _earliest_mp_cleanup_time = new_earliest_time + _last_mp_cleanup_time = new_last_cleanup_time + + +class KeyedUnifiedLock: + """ + Manager for unified keyed locks, supporting both single and multi-process + + • Keeps only a table of async keyed locks locally + • Fetches the multi-process keyed lock on every acquire + • Builds a fresh `UnifiedLock` each time, so `enable_logging` + (or future options) can vary per call. + • Supports dynamic namespaces specified at lock usage time + """ + + def __init__(self, *, default_enable_logging: bool = True) -> None: + self._default_enable_logging = default_enable_logging + self._async_lock: Dict[str, asyncio.Lock] = {} # local keyed locks + self._async_lock_count: Dict[ + str, int + ] = {} # local keyed locks referenced count + self._async_lock_cleanup_data: Dict[ + str, time.time + ] = {} # local keyed locks timeout + self._mp_locks: Dict[ + str, mp.synchronize.Lock + ] = {} # multi-process lock proxies + self._earliest_async_cleanup_time: Optional[float] = ( + None # track earliest async cleanup time + ) + self._last_async_cleanup_time: Optional[float] = ( + None # track last async cleanup time for minimum interval + ) + + def __call__( + self, namespace: str, keys: list[str], *, enable_logging: Optional[bool] = None + ): + """ + Ergonomic helper so you can write: + + async with storage_keyed_lock("namespace", ["key1", "key2"]): + ... + """ + if enable_logging is None: + enable_logging = self._default_enable_logging + return _KeyedLockContext( + self, + namespace=namespace, + keys=keys, + enable_logging=enable_logging, + ) + + def _get_or_create_async_lock(self, combined_key: str) -> asyncio.Lock: + async_lock = self._async_lock.get(combined_key) + count = self._async_lock_count.get(combined_key, 0) + if async_lock is None: + async_lock = asyncio.Lock() + self._async_lock[combined_key] = async_lock + elif count == 0 and combined_key in self._async_lock_cleanup_data: + self._async_lock_cleanup_data.pop(combined_key) + count += 1 + self._async_lock_count[combined_key] = count + return async_lock + + def _release_async_lock(self, combined_key: str): + count = self._async_lock_count.get(combined_key, 0) + count -= 1 + + current_time = time.time() + if count == 0: + self._async_lock_cleanup_data[combined_key] = current_time + + # Update earliest async cleanup time (only when earlier) + if ( + self._earliest_async_cleanup_time is None + or current_time < self._earliest_async_cleanup_time + ): + self._earliest_async_cleanup_time = current_time + self._async_lock_count[combined_key] = count + + # Use generic cleanup function + cleaned_count, new_earliest_time, new_last_cleanup_time = _perform_lock_cleanup( + lock_type="async", + cleanup_data=self._async_lock_cleanup_data, + lock_registry=self._async_lock, + lock_count=self._async_lock_count, + earliest_cleanup_time=self._earliest_async_cleanup_time, + last_cleanup_time=self._last_async_cleanup_time, + current_time=current_time, + threshold_check=True, + ) + + # Update instance state if cleanup was performed + if cleaned_count > 0: + self._earliest_async_cleanup_time = new_earliest_time + self._last_async_cleanup_time = new_last_cleanup_time + + def _get_lock_for_key( + self, namespace: str, key: str, enable_logging: bool = False + ) -> UnifiedLock: + # 1. Create combined key for this namespace:key combination + combined_key = _get_combined_key(namespace, key) + + # 2. get (or create) the per‑process async gate for this combined key + # Is synchronous, so no need to acquire a lock + async_lock = self._get_or_create_async_lock(combined_key) + + # 3. fetch the shared raw lock + raw_lock = _get_or_create_shared_raw_mp_lock(namespace, key) + is_multiprocess = raw_lock is not None + if not is_multiprocess: + raw_lock = async_lock + + # 4. build a *fresh* UnifiedLock with the chosen logging flag + if is_multiprocess: + return UnifiedLock( + lock=raw_lock, + is_async=False, # manager.Lock is synchronous + name=combined_key, + enable_logging=enable_logging, + async_lock=async_lock, # prevents event‑loop blocking + ) + else: + return UnifiedLock( + lock=raw_lock, + is_async=True, + name=combined_key, + enable_logging=enable_logging, + async_lock=None, # No need for async lock in single process mode + ) + + def _release_lock_for_key(self, namespace: str, key: str): + combined_key = _get_combined_key(namespace, key) + self._release_async_lock(combined_key) + _release_shared_raw_mp_lock(namespace, key) + + def cleanup_expired_locks(self) -> Dict[str, Any]: + """ + Cleanup expired locks for both async and multiprocess locks following the same + conditions as _release_shared_raw_mp_lock and _release_async_lock functions. + + Only performs cleanup when both has_expired_locks and interval_satisfied conditions are met + to avoid too frequent cleanup operations. + + Since async and multiprocess locks work together, this method cleans up + both types of expired locks and returns comprehensive statistics. + + Returns: + Dict containing cleanup statistics and current status: + { + "process_id": 12345, + "cleanup_performed": { + "mp_cleaned": 5, + "async_cleaned": 3 + }, + "current_status": { + "total_mp_locks": 10, + "pending_mp_cleanup": 2, + "total_async_locks": 8, + "pending_async_cleanup": 1 + } + } + """ + global _lock_registry, _lock_registry_count, _lock_cleanup_data + global _registry_guard, _earliest_mp_cleanup_time, _last_mp_cleanup_time + + cleanup_stats = {"mp_cleaned": 0, "async_cleaned": 0} + + current_time = time.time() + + # 1. Cleanup multiprocess locks using generic function + if ( + _is_multiprocess + and _lock_registry is not None + and _registry_guard is not None + ): + try: + with _registry_guard: + if _lock_cleanup_data is not None: + # Use generic cleanup function without threshold check + cleaned_count, new_earliest_time, new_last_cleanup_time = ( + _perform_lock_cleanup( + lock_type="mp", + cleanup_data=_lock_cleanup_data, + lock_registry=_lock_registry, + lock_count=_lock_registry_count, + earliest_cleanup_time=_earliest_mp_cleanup_time, + last_cleanup_time=_last_mp_cleanup_time, + current_time=current_time, + threshold_check=False, # Force cleanup in cleanup_expired_locks + ) + ) + + # Update global state if cleanup was performed + if cleaned_count > 0: + _earliest_mp_cleanup_time = new_earliest_time + _last_mp_cleanup_time = new_last_cleanup_time + cleanup_stats["mp_cleaned"] = cleaned_count + + except Exception as e: + direct_log( + f"Error during multiprocess lock cleanup: {e}", + level="ERROR", + enable_output=True, + ) + + # 2. Cleanup async locks using generic function + try: + # Use generic cleanup function without threshold check + cleaned_count, new_earliest_time, new_last_cleanup_time = ( + _perform_lock_cleanup( + lock_type="async", + cleanup_data=self._async_lock_cleanup_data, + lock_registry=self._async_lock, + lock_count=self._async_lock_count, + earliest_cleanup_time=self._earliest_async_cleanup_time, + last_cleanup_time=self._last_async_cleanup_time, + current_time=current_time, + threshold_check=False, # Force cleanup in cleanup_expired_locks + ) + ) + + # Update instance state if cleanup was performed + if cleaned_count > 0: + self._earliest_async_cleanup_time = new_earliest_time + self._last_async_cleanup_time = new_last_cleanup_time + cleanup_stats["async_cleaned"] = cleaned_count + + except Exception as e: + direct_log( + f"Error during async lock cleanup: {e}", + level="ERROR", + enable_output=True, + ) + + # 3. Get current status after cleanup + current_status = self.get_lock_status() + + return { + "process_id": os.getpid(), + "cleanup_performed": cleanup_stats, + "current_status": current_status, + } + + def get_lock_status(self) -> Dict[str, int]: + """ + Get current status of both async and multiprocess locks. + + Returns comprehensive lock counts for both types of locks since + they work together in the keyed lock system. + + Returns: + Dict containing lock counts: + { + "total_mp_locks": 10, + "pending_mp_cleanup": 2, + "total_async_locks": 8, + "pending_async_cleanup": 1 + } + """ + global _lock_registry_count, _lock_cleanup_data, _registry_guard + + status = { + "total_mp_locks": 0, + "pending_mp_cleanup": 0, + "total_async_locks": 0, + "pending_async_cleanup": 0, + } + + try: + # Count multiprocess locks + if _is_multiprocess and _lock_registry_count is not None: + if _registry_guard is not None: + with _registry_guard: + status["total_mp_locks"] = len(_lock_registry_count) + if _lock_cleanup_data is not None: + status["pending_mp_cleanup"] = len(_lock_cleanup_data) + + # Count async locks + status["total_async_locks"] = len(self._async_lock_count) + status["pending_async_cleanup"] = len(self._async_lock_cleanup_data) + + except Exception as e: + direct_log( + f"Error getting keyed lock status: {e}", + level="ERROR", + enable_output=True, + ) + + return status + + +class _KeyedLockContext: + def __init__( + self, + parent: KeyedUnifiedLock, + namespace: str, + keys: list[str], + enable_logging: bool, + ) -> None: + self._parent = parent + self._namespace = namespace + + # The sorting is critical to ensure proper lock and release order + # to avoid deadlocks + self._keys = sorted(keys) + self._enable_logging = ( + enable_logging + if enable_logging is not None + else parent._default_enable_logging + ) + self._ul: Optional[List[Dict[str, Any]]] = None # set in __aenter__ + + # ----- enter ----- + async def __aenter__(self): + if self._ul is not None: + raise RuntimeError("KeyedUnifiedLock already acquired in current context") + + self._ul = [] + + try: + # Acquire locks for all keys in the namespace + for key in self._keys: + lock = None + entry = None + + try: + # 1. Get lock object (reference count is incremented here) + lock = self._parent._get_lock_for_key( + self._namespace, key, enable_logging=self._enable_logging + ) + + # 2. Immediately create and add entry to list (critical for rollback to work) + entry = { + "key": key, + "lock": lock, + "entered": False, + "debug_inc": False, + "ref_incremented": True, # Mark that reference count has been incremented + } + self._ul.append( + entry + ) # Add immediately after _get_lock_for_key for rollback to work + + # 3. Try to acquire the lock + # Use try-finally to ensure state is updated atomically + lock_acquired = False + try: + await lock.__aenter__() + lock_acquired = True # Lock successfully acquired + finally: + if lock_acquired: + entry["entered"] = True + inc_debug_n_locks_acquired() + entry["debug_inc"] = True + + except asyncio.CancelledError: + # Lock acquisition was cancelled + # The finally block above ensures entry["entered"] is correct + direct_log( + f"Lock acquisition cancelled for key {key}", + level="WARNING", + enable_output=self._enable_logging, + ) + raise + except Exception as e: + # Other exceptions, log and re-raise + direct_log( + f"Lock acquisition failed for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + raise + + return self + + except BaseException: + # Critical: if any exception occurs (including CancelledError) during lock acquisition, + # we must rollback all already acquired locks to prevent lock leaks + # Use shield to ensure rollback completes + await asyncio.shield(self._rollback_acquired_locks()) + raise + + async def _rollback_acquired_locks(self): + """Rollback all acquired locks in case of exception during __aenter__""" + if not self._ul: + return + + async def rollback_single_entry(entry): + """Rollback a single lock acquisition""" + key = entry["key"] + lock = entry["lock"] + debug_inc = entry["debug_inc"] + entered = entry["entered"] + ref_incremented = entry.get( + "ref_incremented", True + ) # Default to True for safety + + errors = [] + + # 1. If lock was acquired, release it + if entered: + try: + await lock.__aexit__(None, None, None) + except Exception as e: + errors.append(("lock_exit", e)) + direct_log( + f"Lock rollback error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 2. Release reference count (if it was incremented) + if ref_incremented: + try: + self._parent._release_lock_for_key(self._namespace, key) + except Exception as e: + errors.append(("ref_release", e)) + direct_log( + f"Lock rollback reference release error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 3. Decrement debug counter + if debug_inc: + try: + dec_debug_n_locks_acquired() + except Exception as e: + errors.append(("debug_dec", e)) + direct_log( + f"Lock rollback counter decrementing error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + return errors + + # Release already acquired locks in reverse order + for entry in reversed(self._ul): + # Use shield to protect each lock's rollback + try: + await asyncio.shield(rollback_single_entry(entry)) + except Exception as e: + # Log but continue rolling back other locks + direct_log( + f"Lock rollback unexpected error for {entry['key']}: {e}", + level="ERROR", + enable_output=True, + ) + + self._ul = None + + # ----- exit ----- + async def __aexit__(self, exc_type, exc, tb): + if self._ul is None: + return + + async def release_all_locks(): + """Release all locks with comprehensive error handling, protected from cancellation""" + + async def release_single_entry(entry, exc_type, exc, tb): + """Release a single lock with full protection""" + key = entry["key"] + lock = entry["lock"] + debug_inc = entry["debug_inc"] + entered = entry["entered"] + + errors = [] + + # 1. Release the lock + if entered: + try: + await lock.__aexit__(exc_type, exc, tb) + except Exception as e: + errors.append(("lock_exit", e)) + direct_log( + f"Lock release error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 2. Release reference count + try: + self._parent._release_lock_for_key(self._namespace, key) + except Exception as e: + errors.append(("ref_release", e)) + direct_log( + f"Lock release reference error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + # 3. Decrement debug counter + if debug_inc: + try: + dec_debug_n_locks_acquired() + except Exception as e: + errors.append(("debug_dec", e)) + direct_log( + f"Lock release counter decrementing error for key {key}: {e}", + level="ERROR", + enable_output=True, + ) + + return errors + + all_errors = [] + + # Release locks in reverse order + # This entire loop is protected by the outer shield + for entry in reversed(self._ul): + try: + errors = await release_single_entry(entry, exc_type, exc, tb) + for error_type, error in errors: + all_errors.append((entry["key"], error_type, error)) + except Exception as e: + all_errors.append((entry["key"], "unexpected", e)) + direct_log( + f"Lock release unexpected error for {entry['key']}: {e}", + level="ERROR", + enable_output=True, + ) + + return all_errors + + # CRITICAL: Protect the entire release process with shield + # This ensures that even if cancellation occurs, all locks are released + try: + all_errors = await asyncio.shield(release_all_locks()) + except Exception as e: + direct_log( + f"Critical error during __aexit__ cleanup: {e}", + level="ERROR", + enable_output=True, + ) + all_errors = [] + finally: + # Always clear the lock list, even if shield was cancelled + self._ul = None + + # If there were release errors and no other exception, raise the first release error + if all_errors and exc_type is None: + raise all_errors[0][2] # (key, error_type, error) + + +def get_internal_lock(enable_logging: bool = False) -> UnifiedLock: + """return unified storage lock for data consistency""" + if _internal_lock is None: + raise RuntimeError( + "Shared data not initialized. Call initialize_share_data() before using locks!" + ) + async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None + return UnifiedLock( + lock=_internal_lock, + is_async=not _is_multiprocess, + name="internal_lock", + enable_logging=enable_logging, + async_lock=async_lock, + ) + + +# Workspace based storage_lock is implemented by get_storage_keyed_lock instead. +# Workspace based pipeline_status_lock is implemented by get_storage_keyed_lock instead. +# No need to implement graph_db_lock: +# data integrity is ensured by entity level keyed-lock and allowing only one process to hold pipeline at a time. + + +def get_storage_keyed_lock( + keys: str | list[str], namespace: str = "default", enable_logging: bool = False +) -> _KeyedLockContext: + """Return unified storage keyed lock for ensuring atomic operations across different namespaces""" + global _storage_keyed_lock + if _storage_keyed_lock is None: + raise RuntimeError("Shared-Data is not initialized") + if isinstance(keys, str): + keys = [keys] + return _storage_keyed_lock(namespace, keys, enable_logging=enable_logging) + + +def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock: + """return unified data initialization lock for ensuring atomic data initialization""" + if _data_init_lock is None: + raise RuntimeError( + "Shared data not initialized. Call initialize_share_data() before using locks!" + ) + async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None + return UnifiedLock( + lock=_data_init_lock, + is_async=not _is_multiprocess, + name="data_init_lock", + enable_logging=enable_logging, + async_lock=async_lock, + ) + + +def cleanup_keyed_lock() -> Dict[str, Any]: + """ + Force cleanup of expired keyed locks and return comprehensive status information. + + This function actively cleans up expired locks for both async and multiprocess locks, + then returns detailed statistics about the cleanup operation and current lock status. + + Returns: + Same as cleanup_expired_locks in KeyedUnifiedLock + """ + global _storage_keyed_lock + + # Check if shared storage is initialized + if not _initialized or _storage_keyed_lock is None: + return { + "process_id": os.getpid(), + "cleanup_performed": {"mp_cleaned": 0, "async_cleaned": 0}, + "current_status": { + "total_mp_locks": 0, + "pending_mp_cleanup": 0, + "total_async_locks": 0, + "pending_async_cleanup": 0, + }, + } + + return _storage_keyed_lock.cleanup_expired_locks() + + +def get_keyed_lock_status() -> Dict[str, Any]: + """ + Get current status of keyed locks without performing cleanup. + + This function provides a read-only view of the current lock counts + for both multiprocess and async locks, including pending cleanup counts. + + Returns: + Same as get_lock_status in KeyedUnifiedLock + """ + global _storage_keyed_lock + + # Check if shared storage is initialized + if not _initialized or _storage_keyed_lock is None: + return { + "process_id": os.getpid(), + "total_mp_locks": 0, + "pending_mp_cleanup": 0, + "total_async_locks": 0, + "pending_async_cleanup": 0, + } + + status = _storage_keyed_lock.get_lock_status() + status["process_id"] = os.getpid() + return status + + +def initialize_share_data(workers: int = 1): + """ + Initialize shared storage data for single or multi-process mode. + + When used with Gunicorn's preload feature, this function is called once in the + master process before forking worker processes, allowing all workers to share + the same initialized data. + + In single-process mode, this function is called in FASTAPI lifespan function. + + The function determines whether to use cross-process shared variables for data storage + based on the number of workers. If workers=1, it uses thread locks and local dictionaries. + If workers>1, it uses process locks and shared dictionaries managed by multiprocessing.Manager. + + Args: + workers (int): Number of worker processes. If 1, single-process mode is used. + If > 1, multi-process mode with shared memory is used. + """ + global \ + _manager, \ + _workers, \ + _is_multiprocess, \ + _lock_registry, \ + _lock_registry_count, \ + _lock_cleanup_data, \ + _registry_guard, \ + _internal_lock, \ + _data_init_lock, \ + _shared_dicts, \ + _init_flags, \ + _initialized, \ + _update_flags, \ + _async_locks, \ + _storage_keyed_lock, \ + _earliest_mp_cleanup_time, \ + _last_mp_cleanup_time + + # Check if already initialized + if _initialized: + direct_log( + f"Process {os.getpid()} Shared-Data already initialized (multiprocess={_is_multiprocess})" + ) + return + + _workers = workers + + if workers > 1: + _is_multiprocess = True + _manager = Manager() + _lock_registry = _manager.dict() + _lock_registry_count = _manager.dict() + _lock_cleanup_data = _manager.dict() + _registry_guard = _manager.RLock() + _internal_lock = _manager.Lock() + _data_init_lock = _manager.Lock() + _shared_dicts = _manager.dict() + _init_flags = _manager.dict() + _update_flags = _manager.dict() + + _storage_keyed_lock = KeyedUnifiedLock() + + # Initialize async locks for multiprocess mode + _async_locks = { + "internal_lock": asyncio.Lock(), + "graph_db_lock": asyncio.Lock(), + "data_init_lock": asyncio.Lock(), + } + + direct_log( + f"Process {os.getpid()} Shared-Data created for Multiple Process (workers={workers})" + ) + else: + _is_multiprocess = False + _internal_lock = asyncio.Lock() + _data_init_lock = asyncio.Lock() + _shared_dicts = {} + _init_flags = {} + _update_flags = {} + _async_locks = None # No need for async locks in single process mode + + _storage_keyed_lock = KeyedUnifiedLock() + direct_log(f"Process {os.getpid()} Shared-Data created for Single Process") + + # Initialize multiprocess cleanup times + _earliest_mp_cleanup_time = None + _last_mp_cleanup_time = None + + # Mark as initialized + _initialized = True + + +async def initialize_pipeline_status(workspace: str | None = None): + """ + Initialize pipeline_status share data with default values. + This function could be called before during FASTAPI lifespan for each worker. + + Args: + workspace: Optional workspace identifier for pipeline_status of specific workspace. + If None or empty string, uses the default workspace set by + set_default_workspace(). + """ + pipeline_namespace = await get_namespace_data( + "pipeline_status", first_init=True, workspace=workspace + ) + + async with get_internal_lock(): + # Check if already initialized by checking for required fields + if "busy" in pipeline_namespace: + return + + # Create a shared list object for history_messages + history_messages = _manager.list() if _is_multiprocess else [] + pipeline_namespace.update( + { + "autoscanned": False, # Auto-scan started + "busy": False, # Control concurrent processes + "job_name": "-", # Current job name (indexing files/indexing texts) + "job_start": None, # Job start time + "docs": 0, # Total number of documents to be indexed + "batchs": 0, # Number of batches for processing documents + "cur_batch": 0, # Current processing batch + "request_pending": False, # Flag for pending request for processing + "latest_message": "", # Latest message from pipeline processing + "history_messages": history_messages, # 使用共享列表对象 + } + ) + + final_namespace = get_final_namespace("pipeline_status", workspace) + direct_log( + f"Process {os.getpid()} Pipeline namespace '{final_namespace}' initialized" + ) + + +async def get_update_flag(namespace: str, workspace: str | None = None): + """ + Create a namespace's update flag for a workers. + Returen the update flag to caller for referencing or reset. + """ + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + final_namespace = get_final_namespace(namespace, workspace) + + async with get_internal_lock(): + if final_namespace not in _update_flags: + if _is_multiprocess and _manager is not None: + _update_flags[final_namespace] = _manager.list() + else: + _update_flags[final_namespace] = [] + direct_log( + f"Process {os.getpid()} initialized updated flags for namespace: [{final_namespace}]" + ) + + if _is_multiprocess and _manager is not None: + new_update_flag = _manager.Value("b", False) + else: + # Create a simple mutable object to store boolean value for compatibility with mutiprocess + class MutableBoolean: + def __init__(self, initial_value=False): + self.value = initial_value + + new_update_flag = MutableBoolean(False) + + _update_flags[final_namespace].append(new_update_flag) + return new_update_flag + + +async def set_all_update_flags(namespace: str, workspace: str | None = None): + """Set all update flag of namespace indicating all workers need to reload data from files""" + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + final_namespace = get_final_namespace(namespace, workspace) + + async with get_internal_lock(): + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") + # Update flags for both modes + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = True + + +async def clear_all_update_flags(namespace: str, workspace: str | None = None): + """Clear all update flag of namespace indicating all workers need to reload data from files""" + global _update_flags + if _update_flags is None: + raise ValueError("Try to create namespace before Shared-Data is initialized") + + final_namespace = get_final_namespace(namespace, workspace) + + async with get_internal_lock(): + if final_namespace not in _update_flags: + raise ValueError(f"Namespace {final_namespace} not found in update flags") + # Update flags for both modes + for i in range(len(_update_flags[final_namespace])): + _update_flags[final_namespace][i].value = False + + +async def get_all_update_flags_status(workspace: str | None = None) -> Dict[str, list]: + """ + Get update flags status for all namespaces. + + Returns: + Dict[str, list]: A dictionary mapping namespace names to lists of update flag statuses + """ + if _update_flags is None: + return {} + + if workspace is None: + workspace = get_default_workspace() + + result = {} + async with get_internal_lock(): + for namespace, flags in _update_flags.items(): + # Check if namespace has a workspace prefix (contains ':') + if ":" in namespace: + # Namespace has workspace prefix like "space1:pipeline_status" + # Only include if workspace matches the prefix + # Use rsplit to split from the right since workspace can contain colons + namespace_split = namespace.rsplit(":", 1) + if not workspace or namespace_split[0] != workspace: + continue + else: + # Namespace has no workspace prefix like "pipeline_status" + # Only include if we're querying the default (empty) workspace + if workspace: + continue + + worker_statuses = [] + for flag in flags: + if _is_multiprocess: + worker_statuses.append(flag.value) + else: + worker_statuses.append(flag) + result[namespace] = worker_statuses + + return result + + +async def try_initialize_namespace( + namespace: str, workspace: str | None = None +) -> bool: + """ + Returns True if the current worker(process) gets initialization permission for loading data later. + The worker does not get the permission is prohibited to load data from files. + """ + global _init_flags, _manager + + if _init_flags is None: + raise ValueError("Try to create nanmespace before Shared-Data is initialized") + + final_namespace = get_final_namespace(namespace, workspace) + + async with get_internal_lock(): + if final_namespace not in _init_flags: + _init_flags[final_namespace] = True + direct_log( + f"Process {os.getpid()} ready to initialize storage namespace: [{final_namespace}]" + ) + return True + direct_log( + f"Process {os.getpid()} storage namespace already initialized: [{final_namespace}]" + ) + + return False + + +async def get_namespace_data( + namespace: str, first_init: bool = False, workspace: str | None = None +) -> Dict[str, Any]: + """get the shared data reference for specific namespace + + Args: + namespace: The namespace to retrieve + first_init: If True, allows pipeline_status namespace to create namespace if it doesn't exist. + Prevent getting pipeline_status namespace without initialize_pipeline_status(). + This parameter is used internally by initialize_pipeline_status(). + workspace: Workspace identifier (may be empty string for global namespace) + """ + if _shared_dicts is None: + direct_log( + f"Error: Try to getnanmespace before it is initialized, pid={os.getpid()}", + level="ERROR", + ) + raise ValueError("Shared dictionaries not initialized") + + final_namespace = get_final_namespace(namespace, workspace) + + async with get_internal_lock(): + if final_namespace not in _shared_dicts: + # Special handling for pipeline_status namespace + if ( + final_namespace.endswith(":pipeline_status") + or final_namespace == "pipeline_status" + ) and not first_init: + # Check if pipeline_status should have been initialized but wasn't + # This helps users to call initialize_pipeline_status() before get_namespace_data() + raise PipelineNotInitializedError(final_namespace) + + # For other namespaces or when allow_create=True, create them dynamically + if _is_multiprocess and _manager is not None: + _shared_dicts[final_namespace] = _manager.dict() + else: + _shared_dicts[final_namespace] = {} + + return _shared_dicts[final_namespace] + + +class NamespaceLock: + """ + Reusable namespace lock wrapper that creates a fresh context on each use. + + This class solves the lock re-entrance and concurrent coroutine issues by using + contextvars.ContextVar to provide per-coroutine storage. Each coroutine gets its + own independent lock context, preventing state interference between concurrent + coroutines using the same NamespaceLock instance. + + Example: + lock = NamespaceLock("my_namespace", "workspace1") + + # Can be used multiple times safely + async with lock: + await do_something() + + # Can even be used concurrently without deadlock + await asyncio.gather( + coroutine_1(lock), # Each gets its own context + coroutine_2(lock) # No state interference + ) + """ + + def __init__( + self, namespace: str, workspace: str | None = None, enable_logging: bool = False + ): + self._namespace = namespace + self._workspace = workspace + self._enable_logging = enable_logging + # Use ContextVar to provide per-coroutine storage for lock context + # This ensures each coroutine has its own independent context + self._ctx_var: ContextVar[Optional[_KeyedLockContext]] = ContextVar( + "lock_ctx", default=None + ) + + async def __aenter__(self): + """Create a fresh context each time we enter""" + # Check if this coroutine already has an active lock context + if self._ctx_var.get() is not None: + raise RuntimeError( + "NamespaceLock already acquired in current coroutine context" + ) + + final_namespace = get_final_namespace(self._namespace, self._workspace) + ctx = get_storage_keyed_lock( + ["default_key"], + namespace=final_namespace, + enable_logging=self._enable_logging, + ) + + # Acquire the lock first, then store context only after successful acquisition + # This prevents the ContextVar from being set if acquisition fails (e.g., due to cancellation), + # which would permanently brick the lock + result = await ctx.__aenter__() + self._ctx_var.set(ctx) + return result + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the current context and clean up""" + # Retrieve this coroutine's context + ctx = self._ctx_var.get() + if ctx is None: + raise RuntimeError("NamespaceLock exited without being entered") + + result = await ctx.__aexit__(exc_type, exc_val, exc_tb) + # Clear this coroutine's context + self._ctx_var.set(None) + return result + + +def get_namespace_lock( + namespace: str, workspace: str | None = None, enable_logging: bool = False +) -> NamespaceLock: + """Get a reusable namespace lock wrapper. + + This function returns a NamespaceLock instance that can be used multiple times + safely, even in concurrent scenarios. Each use creates a fresh lock context + internally, preventing lock re-entrance errors. + + Args: + namespace: The namespace to get the lock for. + workspace: Workspace identifier (may be empty string for global namespace) + enable_logging: Whether to enable lock operation logging + + Returns: + NamespaceLock: A reusable lock wrapper that can be used with 'async with' + + Example: + lock = get_namespace_lock("pipeline_status", workspace="space1") + + # Can be used multiple times + async with lock: + await do_something() + + async with lock: + await do_something_else() + """ + return NamespaceLock(namespace, workspace, enable_logging) + + +def finalize_share_data(): + """ + Release shared resources and clean up. + + This function should be called when the application is shutting down + to properly release shared resources and avoid memory leaks. + + In multi-process mode, it shuts down the Manager and releases all shared objects. + In single-process mode, it simply resets the global variables. + """ + global \ + _manager, \ + _is_multiprocess, \ + _internal_lock, \ + _data_init_lock, \ + _shared_dicts, \ + _init_flags, \ + _initialized, \ + _update_flags, \ + _async_locks, \ + _default_workspace + + # Check if already initialized + if not _initialized: + direct_log( + f"Process {os.getpid()} storage data not initialized, nothing to finalize" + ) + return + + direct_log( + f"Process {os.getpid()} finalizing storage data (multiprocess={_is_multiprocess})" + ) + + # In multi-process mode, shut down the Manager + if _is_multiprocess and _manager is not None: + try: + # Clear shared resources before shutting down Manager + if _shared_dicts is not None: + # Clear pipeline status history messages first if exists + try: + pipeline_status = _shared_dicts.get("pipeline_status", {}) + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].clear() + except Exception: + pass # Ignore any errors during history messages cleanup + _shared_dicts.clear() + if _init_flags is not None: + _init_flags.clear() + if _update_flags is not None: + # Clear each namespace's update flags list and Value objects + try: + for namespace in _update_flags: + flags_list = _update_flags[namespace] + if isinstance(flags_list, list): + # Clear Value objects in the list + for flag in flags_list: + if hasattr( + flag, "value" + ): # Check if it's a Value object + flag.value = False + flags_list.clear() + except Exception: + pass # Ignore any errors during update flags cleanup + _update_flags.clear() + + # Shut down the Manager - this will automatically clean up all shared resources + _manager.shutdown() + direct_log(f"Process {os.getpid()} Manager shutdown complete") + except Exception as e: + direct_log( + f"Process {os.getpid()} Error shutting down Manager: {e}", level="ERROR" + ) + + # Reset global variables + _manager = None + _initialized = None + _is_multiprocess = None + _shared_dicts = None + _init_flags = None + _internal_lock = None + _data_init_lock = None + _update_flags = None + _async_locks = None + _default_workspace = None + + direct_log(f"Process {os.getpid()} storage data finalization complete") + + +def set_default_workspace(workspace: str | None = None): + """ + Set default workspace for namespace operations for backward compatibility. + + This allows get_namespace_data(),get_namespace_lock() or initialize_pipeline_status() to + automatically use the correct workspace when called without workspace parameters, + maintaining compatibility with legacy code that doesn't pass workspace explicitly. + + Args: + workspace: Workspace identifier (may be empty string for global namespace) + """ + global _default_workspace + if workspace is None: + workspace = "" + _default_workspace = workspace + direct_log( + f"Default workspace set to: '{_default_workspace}' (empty means global)", + level="DEBUG", + ) + + +def get_default_workspace() -> str: + """ + Get default workspace for backward compatibility. + + Returns: + The default workspace string. Empty string means global namespace. None means not set. + """ + global _default_workspace + return _default_workspace + + +def get_pipeline_status_lock( + enable_logging: bool = False, workspace: str = None +) -> NamespaceLock: + """Return unified storage lock for pipeline status data consistency. + + This function is for compatibility with legacy code only. + """ + global _default_workspace + actual_workspace = workspace if workspace else _default_workspace + return get_namespace_lock( + "pipeline_status", workspace=actual_workspace, enable_logging=enable_logging + ) diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index a4c067372f..50e873060b 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -703,7 +703,10 @@ async def nvidia_openai_complete( @wrap_embedding_func_with_attrs( - embedding_dim=1536, max_token_size=8192, model_name="text-embedding-3-small" + embedding_dim=3072, + max_token_size=8192, + model_name="text-embedding-3-large", + allow_extra_vectors=True, ) @retry( stop=stop_after_attempt(3), @@ -716,7 +719,7 @@ async def nvidia_openai_complete( ) async def openai_embed( texts: list[str], - model: str = "text-embedding-3-small", + model: str = "text-embedding-3-large", base_url: str | None = None, api_key: str | None = None, embedding_dim: int | None = None, @@ -931,9 +934,10 @@ async def azure_openai_complete( @wrap_embedding_func_with_attrs( - embedding_dim=1536, + embedding_dim=3072, max_token_size=8192, model_name="my-text-embedding-3-large-deployment", + allow_extra_vectors=True, ) async def azure_openai_embed( texts: list[str], diff --git a/lightrag/llm/openai_original.py b/lightrag/llm/openai_original.py new file mode 100644 index 0000000000..5bcb511f4b --- /dev/null +++ b/lightrag/llm/openai_original.py @@ -0,0 +1,1020 @@ +from ..utils import verbose_debug, VERBOSE_DEBUG +import os +import logging + +from collections.abc import AsyncIterator + +import pipmaster as pm +import tiktoken + +# install specific modules +if not pm.is_installed("openai"): + pm.install("openai") + +from openai import ( + APIConnectionError, + RateLimitError, + APITimeoutError, +) +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, +) +from lightrag.utils import ( + wrap_embedding_func_with_attrs, + safe_unicode_decode, + logger, +) + +from lightrag.types import GPTKeywordExtractionFormat +from lightrag.api import __api_version__ + +import numpy as np +import base64 +from typing import Any, Union + +from dotenv import load_dotenv + +# Try to import Langfuse for LLM observability (optional) +# Falls back to standard OpenAI client if not available +# Langfuse requires proper configuration to work correctly +LANGFUSE_ENABLED = False +try: + # Check if required Langfuse environment variables are set + langfuse_public_key = os.environ.get("LANGFUSE_PUBLIC_KEY") + langfuse_secret_key = os.environ.get("LANGFUSE_SECRET_KEY") + + # Only enable Langfuse if both keys are configured + if langfuse_public_key and langfuse_secret_key: + from langfuse.openai import AsyncOpenAI # type: ignore[import-untyped] + + LANGFUSE_ENABLED = True + logger.info("Langfuse observability enabled for OpenAI client") + else: + from openai import AsyncOpenAI + + logger.debug( + "Langfuse environment variables not configured, using standard OpenAI client" + ) +except ImportError: + from openai import AsyncOpenAI + + logger.debug("Langfuse not available, using standard OpenAI client") + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + + +class InvalidResponseError(Exception): + """Custom exception class for triggering retry mechanism""" + + pass + + +# Module-level cache for tiktoken encodings +_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {} + + +def _get_tiktoken_encoding_for_model(model: str) -> Any: + """Get tiktoken encoding for the specified model with caching. + + Args: + model: The model name to get encoding for. + + Returns: + The tiktoken encoding for the model. + """ + if model not in _TIKTOKEN_ENCODING_CACHE: + try: + _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model) + except KeyError: + logger.debug( + f"Encoding for model '{model}' not found, falling back to cl100k_base" + ) + _TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base") + return _TIKTOKEN_ENCODING_CACHE[model] + + +def create_openai_async_client( + api_key: str | None = None, + base_url: str | None = None, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, + timeout: int | None = None, + client_configs: dict[str, Any] | None = None, +) -> AsyncOpenAI: + """Create an AsyncOpenAI or AsyncAzureOpenAI client with the given configuration. + + Args: + api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable. + base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL. + use_azure: Whether to create an Azure OpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name (only used when use_azure=True). + api_version: Azure OpenAI API version (only used when use_azure=True). + timeout: Request timeout in seconds. + client_configs: Additional configuration options for the AsyncOpenAI client. + These will override any default configurations but will be overridden by + explicit parameters (api_key, base_url). + + Returns: + An AsyncOpenAI or AsyncAzureOpenAI client instance. + """ + if use_azure: + from openai import AsyncAzureOpenAI + + if not api_key: + api_key = os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get( + "LLM_BINDING_API_KEY" + ) + + if client_configs is None: + client_configs = {} + + # Create a merged config dict with precedence: explicit params > client_configs + merged_configs = { + **client_configs, + "api_key": api_key, + } + + # Add explicit parameters (override client_configs) + if base_url is not None: + merged_configs["azure_endpoint"] = base_url + if azure_deployment is not None: + merged_configs["azure_deployment"] = azure_deployment + if api_version is not None: + merged_configs["api_version"] = api_version + if timeout is not None: + merged_configs["timeout"] = timeout + + return AsyncAzureOpenAI(**merged_configs) + else: + if not api_key: + api_key = os.environ["OPENAI_API_KEY"] + + default_headers = { + "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}", + "Content-Type": "application/json", + } + + if client_configs is None: + client_configs = {} + + # Create a merged config dict with precedence: explicit params > client_configs > defaults + merged_configs = { + **client_configs, + "default_headers": default_headers, + "api_key": api_key, + } + + if base_url is not None: + merged_configs["base_url"] = base_url + else: + merged_configs["base_url"] = os.environ.get( + "OPENAI_API_BASE", "https://api.openai.com/v1" + ) + + if timeout is not None: + merged_configs["timeout"] = timeout + + return AsyncOpenAI(**merged_configs) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=( + retry_if_exception_type(RateLimitError) + | retry_if_exception_type(APIConnectionError) + | retry_if_exception_type(APITimeoutError) + | retry_if_exception_type(InvalidResponseError) + ), +) +async def openai_complete_if_cache( + model: str, + prompt: str, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] | None = None, + enable_cot: bool = False, + base_url: str | None = None, + api_key: str | None = None, + token_tracker: Any | None = None, + stream: bool | None = None, + timeout: int | None = None, + keyword_extraction: bool = False, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, + **kwargs: Any, +) -> str: + """Complete a prompt using OpenAI's API with caching support and Chain of Thought (COT) integration. + + This function supports automatic integration of reasoning content from models that provide + Chain of Thought capabilities. The reasoning content is seamlessly integrated into the response + using ... tags. + + Note on `reasoning_content`: This feature relies on a Deepseek Style `reasoning_content` + in the API response, which may be provided by OpenAI-compatible endpoints that support + Chain of Thought. + + COT Integration Rules: + 1. COT content is accepted only when regular content is empty and `reasoning_content` has content. + 2. COT processing stops when regular content becomes available. + 3. If both `content` and `reasoning_content` are present simultaneously, reasoning is ignored. + 4. If both fields have content from the start, COT is never activated. + 5. For streaming: COT content is inserted into the content stream with tags. + 6. For non-streaming: COT content is prepended to regular content with tags. + + Args: + model: The OpenAI model to use. For Azure, this can be the deployment name. + prompt: The prompt to complete. + system_prompt: Optional system prompt to include. + history_messages: Optional list of previous messages in the conversation. + enable_cot: Whether to enable Chain of Thought (COT) processing. Default is False. + base_url: Optional base URL for the OpenAI API. For Azure, this should be the + Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). + api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment + variable if None. For Azure, uses AZURE_OPENAI_API_KEY if None. + token_tracker: Optional token usage tracker for monitoring API usage. + stream: Whether to stream the response. Default is False. + timeout: Request timeout in seconds. Default is None. + keyword_extraction: Whether to enable keyword extraction mode. When True, triggers + special response formatting for keyword extraction. Default is False. + use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. + When True, creates an AsyncAzureOpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. + If not specified, falls back to AZURE_OPENAI_DEPLOYMENT environment variable. + api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used + when use_azure=True. If not specified, falls back to AZURE_OPENAI_API_VERSION + environment variable. + **kwargs: Additional keyword arguments to pass to the OpenAI API. + Special kwargs: + - openai_client_configs: Dict of configuration options for the AsyncOpenAI client. + These will be passed to the client constructor but will be overridden by + explicit parameters (api_key, base_url). Supports proxy configuration, + custom headers, retry policies, etc. + + Returns: + The completed text (with integrated COT content if available) or an async iterator + of text chunks if streaming. COT content is wrapped in ... tags. + + Raises: + InvalidResponseError: If the response from OpenAI is invalid or empty. + APIConnectionError: If there is a connection error with the OpenAI API. + RateLimitError: If the OpenAI API rate limit is exceeded. + APITimeoutError: If the OpenAI API request times out. + """ + if history_messages is None: + history_messages = [] + + # Set openai logger level to INFO when VERBOSE_DEBUG is off + if not VERBOSE_DEBUG and logger.level == logging.DEBUG: + logging.getLogger("openai").setLevel(logging.INFO) + + # Remove special kwargs that shouldn't be passed to OpenAI + kwargs.pop("hashing_kv", None) + + # Extract client configuration options + client_configs = kwargs.pop("openai_client_configs", {}) + + # Handle keyword extraction mode + if keyword_extraction: + kwargs["response_format"] = GPTKeywordExtractionFormat + + # Create the OpenAI client (supports both OpenAI and Azure) + openai_async_client = create_openai_async_client( + api_key=api_key, + base_url=base_url, + use_azure=use_azure, + azure_deployment=azure_deployment, + api_version=api_version, + timeout=timeout, + client_configs=client_configs, + ) + + # Prepare messages + messages: list[dict[str, Any]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + + logger.debug("===== Entering func of LLM =====") + logger.debug(f"Model: {model} Base URL: {base_url}") + logger.debug(f"Client Configs: {client_configs}") + logger.debug(f"Additional kwargs: {kwargs}") + logger.debug(f"Num of history messages: {len(history_messages)}") + verbose_debug(f"System prompt: {system_prompt}") + verbose_debug(f"Query: {prompt}") + logger.debug("===== Sending Query to LLM =====") + + messages = kwargs.pop("messages", messages) + + # Add explicit parameters back to kwargs so they're passed to OpenAI API + if stream is not None: + kwargs["stream"] = stream + if timeout is not None: + kwargs["timeout"] = timeout + + # Determine the correct model identifier to use + # For Azure OpenAI, we must use the deployment name instead of the model name + api_model = azure_deployment if use_azure and azure_deployment else model + + try: + # Don't use async with context manager, use client directly + if "response_format" in kwargs: + response = await openai_async_client.chat.completions.parse( + model=api_model, messages=messages, **kwargs + ) + else: + response = await openai_async_client.chat.completions.create( + model=api_model, messages=messages, **kwargs + ) + except APITimeoutError as e: + logger.error(f"OpenAI API Timeout Error: {e}") + await openai_async_client.close() # Ensure client is closed + raise + except APIConnectionError as e: + logger.error(f"OpenAI API Connection Error: {e}") + await openai_async_client.close() # Ensure client is closed + raise + except RateLimitError as e: + logger.error(f"OpenAI API Rate Limit Error: {e}") + await openai_async_client.close() # Ensure client is closed + raise + except Exception as e: + logger.error( + f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}" + ) + await openai_async_client.close() # Ensure client is closed + raise + + if hasattr(response, "__aiter__"): + + async def inner(): + # Track if we've started iterating + iteration_started = False + final_chunk_usage = None + + # COT (Chain of Thought) state tracking + cot_active = False + cot_started = False + initial_content_seen = False + + try: + iteration_started = True + async for chunk in response: + # Check if this chunk has usage information (final chunk) + if hasattr(chunk, "usage") and chunk.usage: + final_chunk_usage = chunk.usage + logger.debug( + f"Received usage info in streaming chunk: {chunk.usage}" + ) + + # Check if choices exists and is not empty + if not hasattr(chunk, "choices") or not chunk.choices: + # Azure OpenAI sends content filter results in first chunk without choices + logger.debug( + f"Received chunk without choices (likely Azure content filter): {chunk}" + ) + continue + + # Check if delta exists + if not hasattr(chunk.choices[0], "delta"): + # This might be the final chunk, continue to check for usage + continue + + delta = chunk.choices[0].delta + content = getattr(delta, "content", None) + reasoning_content = getattr(delta, "reasoning_content", "") + + # Handle COT logic for streaming (only if enabled) + if enable_cot: + if content: + # Regular content is present + if not initial_content_seen: + initial_content_seen = True + # If both content and reasoning_content are present initially, don't start COT + if reasoning_content: + cot_active = False + cot_started = False + + # If COT was active, end it + if cot_active: + yield "" + cot_active = False + + # Process regular content + if r"\u" in content: + content = safe_unicode_decode(content.encode("utf-8")) + yield content + + elif reasoning_content: + # Only reasoning content is present + if not initial_content_seen and not cot_started: + # Start COT if we haven't seen initial content yet + if not cot_active: + yield "" + cot_active = True + cot_started = True + + # Process reasoning content if COT is active + if cot_active: + if r"\u" in reasoning_content: + reasoning_content = safe_unicode_decode( + reasoning_content.encode("utf-8") + ) + yield reasoning_content + else: + # COT disabled, only process regular content + if content: + if r"\u" in content: + content = safe_unicode_decode(content.encode("utf-8")) + yield content + + # If neither content nor reasoning_content, continue to next chunk + if content is None and reasoning_content is None: + continue + + # Ensure COT is properly closed if still active after stream ends + if enable_cot and cot_active: + yield "" + cot_active = False + + # After streaming is complete, track token usage + if token_tracker and final_chunk_usage: + # Use actual usage from the API + token_counts = { + "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0), + "completion_tokens": getattr( + final_chunk_usage, "completion_tokens", 0 + ), + "total_tokens": getattr(final_chunk_usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Streaming token usage (from API): {token_counts}") + elif token_tracker: + logger.debug("No usage information available in streaming response") + except Exception as e: + # Ensure COT is properly closed before handling exception + if enable_cot and cot_active: + try: + yield "" + cot_active = False + except Exception as close_error: + logger.warning( + f"Failed to close COT tag during exception handling: {close_error}" + ) + + logger.error(f"Error in stream response: {str(e)}") + # Try to clean up resources if possible + if ( + iteration_started + and hasattr(response, "aclose") + and callable(getattr(response, "aclose", None)) + ): + try: + await response.aclose() + logger.debug("Successfully closed stream response after error") + except Exception as close_error: + logger.warning( + f"Failed to close stream response: {close_error}" + ) + # Ensure client is closed in case of exception + await openai_async_client.close() + raise + finally: + # Final safety check for unclosed COT tags + if enable_cot and cot_active: + try: + yield "" + cot_active = False + except Exception as final_close_error: + logger.warning( + f"Failed to close COT tag in finally block: {final_close_error}" + ) + + # Ensure resources are released even if no exception occurs + # Note: Some wrapped clients (e.g., Langfuse) may not implement aclose() properly + if iteration_started and hasattr(response, "aclose"): + aclose_method = getattr(response, "aclose", None) + if callable(aclose_method): + try: + await response.aclose() + logger.debug("Successfully closed stream response") + except (AttributeError, TypeError) as close_error: + # Some wrapper objects may report hasattr(aclose) but fail when called + # This is expected behavior for certain client wrappers + logger.debug( + f"Stream response cleanup not supported by client wrapper: {close_error}" + ) + except Exception as close_error: + logger.warning( + f"Unexpected error during stream response cleanup: {close_error}" + ) + + # This prevents resource leaks since the caller doesn't handle closing + try: + await openai_async_client.close() + logger.debug( + "Successfully closed OpenAI client for streaming response" + ) + except Exception as client_close_error: + logger.warning( + f"Failed to close OpenAI client in streaming finally block: {client_close_error}" + ) + + return inner() + + else: + try: + if ( + not response + or not response.choices + or not hasattr(response.choices[0], "message") + ): + logger.error("Invalid response from OpenAI API") + await openai_async_client.close() # Ensure client is closed + raise InvalidResponseError("Invalid response from OpenAI API") + + message = response.choices[0].message + + # Handle parsed responses (structured output via response_format) + # When using beta.chat.completions.parse(), the response is in message.parsed + if hasattr(message, "parsed") and message.parsed is not None: + # Serialize the parsed structured response to JSON + final_content = message.parsed.model_dump_json() + logger.debug("Using parsed structured response from API") + else: + # Handle regular content responses + content = getattr(message, "content", None) + reasoning_content = getattr(message, "reasoning_content", "") + + # Handle COT logic for non-streaming responses (only if enabled) + final_content = "" + + if enable_cot: + # Check if we should include reasoning content + should_include_reasoning = False + if reasoning_content and reasoning_content.strip(): + if not content or content.strip() == "": + # Case 1: Only reasoning content, should include COT + should_include_reasoning = True + final_content = ( + content or "" + ) # Use empty string if content is None + else: + # Case 3: Both content and reasoning_content present, ignore reasoning + should_include_reasoning = False + final_content = content + else: + # No reasoning content, use regular content + final_content = content or "" + + # Apply COT wrapping if needed + if should_include_reasoning: + if r"\u" in reasoning_content: + reasoning_content = safe_unicode_decode( + reasoning_content.encode("utf-8") + ) + final_content = ( + f"{reasoning_content}{final_content}" + ) + else: + # COT disabled, only use regular content + final_content = content or "" + + # Validate final content + if not final_content or final_content.strip() == "": + logger.error("Received empty content from OpenAI API") + await openai_async_client.close() # Ensure client is closed + raise InvalidResponseError("Received empty content from OpenAI API") + + # Apply Unicode decoding to final content if needed + if r"\u" in final_content: + final_content = safe_unicode_decode(final_content.encode("utf-8")) + + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "completion_tokens": getattr( + response.usage, "completion_tokens", 0 + ), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + + logger.debug(f"Response content len: {len(final_content)}") + verbose_debug(f"Response: {response}") + + return final_content + finally: + # Ensure client is closed in all cases for non-streaming responses + await openai_async_client.close() + + +async def openai_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, +) -> Union[str, AsyncIterator[str]]: + if history_messages is None: + history_messages = [] + model_name = kwargs["hashing_kv"].global_config["llm_model_name"] + return await openai_complete_if_cache( + model_name, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def gpt_4o_complete( + prompt, + system_prompt=None, + history_messages=None, + enable_cot: bool = False, + keyword_extraction=False, + **kwargs, +) -> str: + if history_messages is None: + history_messages = [] + return await openai_complete_if_cache( + "gpt-4o", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def gpt_4o_mini_complete( + prompt, + system_prompt=None, + history_messages=None, + enable_cot: bool = False, + keyword_extraction=False, + **kwargs, +) -> str: + if history_messages is None: + history_messages = [] + return await openai_complete_if_cache( + "gpt-4o-mini", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def nvidia_openai_complete( + prompt, + system_prompt=None, + history_messages=None, + enable_cot: bool = False, + keyword_extraction=False, + **kwargs, +) -> str: + if history_messages is None: + history_messages = [] + result = await openai_complete_if_cache( + "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + keyword_extraction=keyword_extraction, + base_url="https://integrate.api.nvidia.com/v1", + **kwargs, + ) + return result + + +@wrap_embedding_func_with_attrs( + embedding_dim=1536, max_token_size=8192, model_name="text-embedding-3-small" +) +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=60), + retry=( + retry_if_exception_type(RateLimitError) + | retry_if_exception_type(APIConnectionError) + | retry_if_exception_type(APITimeoutError) + ), +) +async def openai_embed( + texts: list[str], + model: str = "text-embedding-3-small", + base_url: str | None = None, + api_key: str | None = None, + embedding_dim: int | None = None, + max_token_size: int | None = None, + client_configs: dict[str, Any] | None = None, + token_tracker: Any | None = None, + use_azure: bool = False, + azure_deployment: str | None = None, + api_version: str | None = None, +) -> np.ndarray: + """Generate embeddings for a list of texts using OpenAI's API with automatic text truncation. + + This function supports both standard OpenAI and Azure OpenAI services. It automatically + truncates texts that exceed the model's token limit to prevent API errors. + + Args: + texts: List of texts to embed. + model: The embedding model to use. For standard OpenAI (e.g., "text-embedding-3-small"). + For Azure, this can be the deployment name. + base_url: Optional base URL for the API. For standard OpenAI, uses default OpenAI endpoint. + For Azure, this should be the Azure OpenAI endpoint (e.g., https://your-resource.openai.azure.com/). + api_key: Optional API key. For standard OpenAI, uses OPENAI_API_KEY environment variable if None. + For Azure, uses AZURE_EMBEDDING_API_KEY environment variable if None. + embedding_dim: Optional embedding dimension for dynamic dimension reduction. + **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper. + Do NOT manually pass this parameter when calling the function directly. + The dimension is controlled by the @wrap_embedding_func_with_attrs decorator. + Manually passing a different value will trigger a warning and be ignored. + When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction. + max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated. + **IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper + when the underlying function signature supports it (via inspect.signature check). + The value is controlled by the @wrap_embedding_func_with_attrs decorator. + Set max_token_size=0 to disable truncation. + client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client. + These will override any default configurations but will be overridden by + explicit parameters (api_key, base_url). Supports proxy configuration, + custom headers, retry policies, etc. + token_tracker: Optional token usage tracker for monitoring API usage. + use_azure: Whether to use Azure OpenAI service instead of standard OpenAI. + When True, creates an AsyncAzureOpenAI client. Default is False. + azure_deployment: Azure OpenAI deployment name. Only used when use_azure=True. + If not specified, falls back to AZURE_EMBEDDING_DEPLOYMENT environment variable. + api_version: Azure OpenAI API version (e.g., "2024-02-15-preview"). Only used + when use_azure=True. If not specified, falls back to AZURE_EMBEDDING_API_VERSION + environment variable. + + Returns: + A numpy array of embeddings, one per input text. + + Raises: + APIConnectionError: If there is a connection error with the OpenAI API. + RateLimitError: If the OpenAI API rate limit is exceeded. + APITimeoutError: If the OpenAI API request times out. + """ + # Apply text truncation if max_token_size is provided + if max_token_size is not None and max_token_size > 0: + encoding = _get_tiktoken_encoding_for_model(model) + truncated_texts = [] + truncation_count = 0 + + for text in texts: + if not text: + truncated_texts.append(text) + continue + + tokens = encoding.encode(text) + if len(tokens) > max_token_size: + truncated_tokens = tokens[:max_token_size] + truncated_texts.append(encoding.decode(truncated_tokens)) + truncation_count += 1 + logger.debug( + f"Text truncated from {len(tokens)} to {max_token_size} tokens" + ) + else: + truncated_texts.append(text) + + if truncation_count > 0: + logger.info( + f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})" + ) + + texts = truncated_texts + + # Create the OpenAI client (supports both OpenAI and Azure) + openai_async_client = create_openai_async_client( + api_key=api_key, + base_url=base_url, + use_azure=use_azure, + azure_deployment=azure_deployment, + api_version=api_version, + client_configs=client_configs, + ) + + async with openai_async_client: + # Determine the correct model identifier to use + # For Azure OpenAI, we must use the deployment name instead of the model name + api_model = azure_deployment if use_azure and azure_deployment else model + + # Prepare API call parameters + api_params = { + "model": api_model, + "input": texts, + "encoding_format": "base64", + } + + # Add dimensions parameter only if embedding_dim is provided + if embedding_dim is not None: + api_params["dimensions"] = embedding_dim + + # Make API call + response = await openai_async_client.embeddings.create(**api_params) + + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + + return np.array( + [ + np.array(dp.embedding, dtype=np.float32) + if isinstance(dp.embedding, list) + else np.frombuffer(base64.b64decode(dp.embedding), dtype=np.float32) + for dp in response.data + ] + ) + + +# Azure OpenAI wrapper functions for backward compatibility +async def azure_openai_complete_if_cache( + model, + prompt, + system_prompt: str | None = None, + history_messages: list[dict[str, Any]] | None = None, + enable_cot: bool = False, + base_url: str | None = None, + api_key: str | None = None, + token_tracker: Any | None = None, + stream: bool | None = None, + timeout: int | None = None, + api_version: str | None = None, + keyword_extraction: bool = False, + **kwargs, +): + """Azure OpenAI completion wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_complete_if_cache implementation with Azure-specific parameter handling. + + All parameters from the underlying openai_complete_if_cache are exposed to ensure + full feature parity and API consistency. + """ + # Handle Azure-specific environment variables and parameters + deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL") + base_url = ( + base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST") + ) + api_key = ( + api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_OPENAI_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + or "2024-08-01-preview" + ) + + # Call the unified implementation with Azure-specific parameters + return await openai_complete_if_cache( + model=deployment, + prompt=prompt, + system_prompt=system_prompt, + history_messages=history_messages, + enable_cot=enable_cot, + base_url=base_url, + api_key=api_key, + token_tracker=token_tracker, + stream=stream, + timeout=timeout, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + keyword_extraction=keyword_extraction, + **kwargs, + ) + + +async def azure_openai_complete( + prompt, + system_prompt=None, + history_messages=None, + keyword_extraction=False, + **kwargs, +) -> str: + """Azure OpenAI complete wrapper function. + + Provides backward compatibility for azure_openai_complete calls. + """ + if history_messages is None: + history_messages = [] + result = await azure_openai_complete_if_cache( + os.getenv("LLM_MODEL", "gpt-4o-mini"), + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + keyword_extraction=keyword_extraction, + **kwargs, + ) + return result + + +@wrap_embedding_func_with_attrs( + embedding_dim=1536, + max_token_size=8192, + model_name="my-text-embedding-3-large-deployment", +) +async def azure_openai_embed( + texts: list[str], + model: str | None = None, + base_url: str | None = None, + api_key: str | None = None, + token_tracker: Any | None = None, + client_configs: dict[str, Any] | None = None, + api_version: str | None = None, +) -> np.ndarray: + """Azure OpenAI embedding wrapper function. + + This function provides backward compatibility by wrapping the unified + openai_embed implementation with Azure-specific parameter handling. + + All parameters from the underlying openai_embed are exposed to ensure + full feature parity and API consistency. + + IMPORTANT - Decorator Usage: + + 1. This function is decorated with @wrap_embedding_func_with_attrs to provide + the EmbeddingFunc interface for users who need to access embedding_dim + and other attributes. + + 2. This function does NOT use @retry decorator to avoid double-wrapping, + since the underlying openai_embed.func already has retry logic. + + 3. This function calls openai_embed.func (the unwrapped function) instead of + openai_embed (the EmbeddingFunc instance) to avoid double decoration issues: + + ✅ Correct: await openai_embed.func(...) # Calls unwrapped function with retry + ❌ Wrong: await openai_embed(...) # Would cause double EmbeddingFunc wrapping + + Double decoration causes: + - Double injection of embedding_dim parameter + - Incorrect parameter passing to the underlying implementation + - Runtime errors due to parameter conflicts + + The call chain with correct implementation: + azure_openai_embed(texts) + → EmbeddingFunc.__call__(texts) # azure's decorator + → azure_openai_embed_impl(texts, embedding_dim=1536) + → openai_embed.func(texts, ...) + → @retry_wrapper(texts, ...) # openai's retry (only one layer) + → openai_embed_impl(texts, ...) + → actual embedding computation + """ + # Handle Azure-specific environment variables and parameters + deployment = ( + os.getenv("AZURE_EMBEDDING_DEPLOYMENT") + or model + or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") + ) + base_url = ( + base_url + or os.getenv("AZURE_EMBEDDING_ENDPOINT") + or os.getenv("EMBEDDING_BINDING_HOST") + ) + api_key = ( + api_key + or os.getenv("AZURE_EMBEDDING_API_KEY") + or os.getenv("EMBEDDING_BINDING_API_KEY") + ) + api_version = ( + api_version + or os.getenv("AZURE_EMBEDDING_API_VERSION") + or os.getenv("AZURE_OPENAI_API_VERSION") + or os.getenv("OPENAI_API_VERSION") + or "2024-08-01-preview" + ) + + # CRITICAL: Call openai_embed.func (unwrapped) to avoid double decoration + # openai_embed is an EmbeddingFunc instance, .func accesses the underlying function + return await openai_embed.func( + texts=texts, + model=deployment, + base_url=base_url, + api_key=api_key, + token_tracker=token_tracker, + client_configs=client_configs, + use_azure=True, + azure_deployment=deployment, + api_version=api_version, + ) diff --git a/lightrag/operate.py b/lightrag/operate.py index 20af067908..f723ccb77a 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -4087,6 +4087,11 @@ async def _build_context_str( text_units_str = "\n".join( json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context ) + + if query_param.only_kg_context: + text_units_str = "" + reference_list_str = "" + reference_list_str = "\n".join( f"[{ref['reference_id']}] {ref['file_path']}" for ref in reference_list diff --git a/lightrag/operate_original.py b/lightrag/operate_original.py new file mode 100644 index 0000000000..5389672ed9 --- /dev/null +++ b/lightrag/operate_original.py @@ -0,0 +1,5024 @@ +from __future__ import annotations +from functools import partial +from pathlib import Path + +import asyncio +import json +import json_repair +from typing import Any, AsyncIterator, overload, Literal +from collections import Counter, defaultdict + +from lightrag.exceptions import ( + PipelineCancelledException, + ChunkTokenLimitExceededError, +) +from lightrag.utils import ( + logger, + compute_mdhash_id, + Tokenizer, + is_float_regex, + sanitize_and_normalize_extracted_text, + pack_user_ass_to_openai_messages, + split_string_by_multi_markers, + truncate_list_by_token_size, + compute_args_hash, + handle_cache, + save_to_cache, + CacheData, + use_llm_func_with_cache, + update_chunk_cache_list, + remove_think_tags, + pick_by_weighted_polling, + pick_by_vector_similarity, + process_chunks_unified, + safe_vdb_operation_with_exception, + create_prefixed_exception, + fix_tuple_delimiter_corruption, + convert_to_user_format, + generate_reference_list_from_chunks, + apply_source_ids_limit, + merge_source_ids, + make_relation_chunk_key, +) +from lightrag.base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, + TextChunkSchema, + QueryParam, + QueryResult, + QueryContextResult, +) +from lightrag.prompt import PROMPTS +from lightrag.constants import ( + GRAPH_FIELD_SEP, + DEFAULT_MAX_ENTITY_TOKENS, + DEFAULT_MAX_RELATION_TOKENS, + DEFAULT_MAX_TOTAL_TOKENS, + DEFAULT_RELATED_CHUNK_NUMBER, + DEFAULT_KG_CHUNK_PICK_METHOD, + DEFAULT_ENTITY_TYPES, + DEFAULT_SUMMARY_LANGUAGE, + SOURCE_IDS_LIMIT_METHOD_KEEP, + SOURCE_IDS_LIMIT_METHOD_FIFO, + DEFAULT_FILE_PATH_MORE_PLACEHOLDER, + DEFAULT_MAX_FILE_PATHS, + DEFAULT_ENTITY_NAME_MAX_LENGTH, +) +from lightrag.kg.shared_storage import get_storage_keyed_lock +import time +from dotenv import load_dotenv + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=Path(__file__).resolve().parent / ".env", override=False) + + +def _truncate_entity_identifier( + identifier: str, limit: int, chunk_key: str, identifier_role: str +) -> str: + """Truncate entity identifiers that exceed the configured length limit.""" + + if len(identifier) <= limit: + return identifier + + display_value = identifier[:limit] + preview = identifier[:20] # Show first 20 characters as preview + logger.warning( + "%s: %s len %d > %d chars (Name: '%s...')", + chunk_key, + identifier_role, + len(identifier), + limit, + preview, + ) + return display_value + + +def chunking_by_token_size( + tokenizer: Tokenizer, + content: str, + split_by_character: str | None = None, + split_by_character_only: bool = False, + chunk_overlap_token_size: int = 100, + chunk_token_size: int = 1200, +) -> list[dict[str, Any]]: + tokens = tokenizer.encode(content) + results: list[dict[str, Any]] = [] + if split_by_character: + raw_chunks = content.split(split_by_character) + new_chunks = [] + if split_by_character_only: + for chunk in raw_chunks: + _tokens = tokenizer.encode(chunk) + if len(_tokens) > chunk_token_size: + logger.warning( + "Chunk split_by_character exceeds token limit: len=%d limit=%d", + len(_tokens), + chunk_token_size, + ) + raise ChunkTokenLimitExceededError( + chunk_tokens=len(_tokens), + chunk_token_limit=chunk_token_size, + chunk_preview=chunk[:120], + ) + new_chunks.append((len(_tokens), chunk)) + else: + for chunk in raw_chunks: + _tokens = tokenizer.encode(chunk) + if len(_tokens) > chunk_token_size: + for start in range( + 0, len(_tokens), chunk_token_size - chunk_overlap_token_size + ): + chunk_content = tokenizer.decode( + _tokens[start : start + chunk_token_size] + ) + new_chunks.append( + (min(chunk_token_size, len(_tokens) - start), chunk_content) + ) + else: + new_chunks.append((len(_tokens), chunk)) + for index, (_len, chunk) in enumerate(new_chunks): + results.append( + { + "tokens": _len, + "content": chunk.strip(), + "chunk_order_index": index, + } + ) + else: + for index, start in enumerate( + range(0, len(tokens), chunk_token_size - chunk_overlap_token_size) + ): + chunk_content = tokenizer.decode(tokens[start : start + chunk_token_size]) + results.append( + { + "tokens": min(chunk_token_size, len(tokens) - start), + "content": chunk_content.strip(), + "chunk_order_index": index, + } + ) + return results + + +async def _handle_entity_relation_summary( + description_type: str, + entity_or_relation_name: str, + description_list: list[str], + separator: str, + global_config: dict, + llm_response_cache: BaseKVStorage | None = None, +) -> tuple[str, bool]: + """Handle entity relation description summary using map-reduce approach. + + This function summarizes a list of descriptions using a map-reduce strategy: + 1. If total tokens < summary_context_size and len(description_list) < force_llm_summary_on_merge, no need to summarize + 2. If total tokens < summary_max_tokens, summarize with LLM directly + 3. Otherwise, split descriptions into chunks that fit within token limits + 4. Summarize each chunk, then recursively process the summaries + 5. Continue until we get a final summary within token limits or num of descriptions is less than force_llm_summary_on_merge + + Args: + entity_or_relation_name: Name of the entity or relation being summarized + description_list: List of description strings to summarize + global_config: Global configuration containing tokenizer and limits + llm_response_cache: Optional cache for LLM responses + + Returns: + Tuple of (final_summarized_description_string, llm_was_used_boolean) + """ + # Handle empty input + if not description_list: + return "", False + + # If only one description, return it directly (no need for LLM call) + if len(description_list) == 1: + return description_list[0], False + + # Get configuration + tokenizer: Tokenizer = global_config["tokenizer"] + summary_context_size = global_config["summary_context_size"] + summary_max_tokens = global_config["summary_max_tokens"] + force_llm_summary_on_merge = global_config["force_llm_summary_on_merge"] + + current_list = description_list[:] # Copy the list to avoid modifying original + llm_was_used = False # Track whether LLM was used during the entire process + + # Iterative map-reduce process + while True: + # Calculate total tokens in current list + total_tokens = sum(len(tokenizer.encode(desc)) for desc in current_list) + + # If total length is within limits, perform final summarization + if total_tokens <= summary_context_size or len(current_list) <= 2: + if ( + len(current_list) < force_llm_summary_on_merge + and total_tokens < summary_max_tokens + ): + # no LLM needed, just join the descriptions + final_description = separator.join(current_list) + return final_description if final_description else "", llm_was_used + else: + if total_tokens > summary_context_size and len(current_list) <= 2: + logger.warning( + f"Summarizing {entity_or_relation_name}: Oversize description found" + ) + # Final summarization of remaining descriptions - LLM will be used + final_summary = await _summarize_descriptions( + description_type, + entity_or_relation_name, + current_list, + global_config, + llm_response_cache, + ) + return final_summary, True # LLM was used for final summarization + + # Need to split into chunks - Map phase + # Ensure each chunk has minimum 2 descriptions to guarantee progress + chunks = [] + current_chunk = [] + current_tokens = 0 + + # Currently least 3 descriptions in current_list + for i, desc in enumerate(current_list): + desc_tokens = len(tokenizer.encode(desc)) + + # If adding current description would exceed limit, finalize current chunk + if current_tokens + desc_tokens > summary_context_size and current_chunk: + # Ensure we have at least 2 descriptions in the chunk (when possible) + if len(current_chunk) == 1: + # Force add one more description to ensure minimum 2 per chunk + current_chunk.append(desc) + chunks.append(current_chunk) + logger.warning( + f"Summarizing {entity_or_relation_name}: Oversize description found" + ) + current_chunk = [] # next group is empty + current_tokens = 0 + else: # curren_chunk is ready for summary in reduce phase + chunks.append(current_chunk) + current_chunk = [desc] # leave it for next group + current_tokens = desc_tokens + else: + current_chunk.append(desc) + current_tokens += desc_tokens + + # Add the last chunk if it exists + if current_chunk: + chunks.append(current_chunk) + + logger.info( + f" Summarizing {entity_or_relation_name}: Map {len(current_list)} descriptions into {len(chunks)} groups" + ) + + # Reduce phase: summarize each group from chunks + new_summaries = [] + for chunk in chunks: + if len(chunk) == 1: + # Optimization: single description chunks don't need LLM summarization + new_summaries.append(chunk[0]) + else: + # Multiple descriptions need LLM summarization + summary = await _summarize_descriptions( + description_type, + entity_or_relation_name, + chunk, + global_config, + llm_response_cache, + ) + new_summaries.append(summary) + llm_was_used = True # Mark that LLM was used in reduce phase + + # Update current list with new summaries for next iteration + current_list = new_summaries + + +async def _summarize_descriptions( + description_type: str, + description_name: str, + description_list: list[str], + global_config: dict, + llm_response_cache: BaseKVStorage | None = None, +) -> str: + """Helper function to summarize a list of descriptions using LLM. + + Args: + entity_or_relation_name: Name of the entity or relation being summarized + descriptions: List of description strings to summarize + global_config: Global configuration containing LLM function and settings + llm_response_cache: Optional cache for LLM responses + + Returns: + Summarized description string + """ + use_llm_func: callable = global_config["llm_model_func"] + # Apply higher priority (8) to entity/relation summary tasks + use_llm_func = partial(use_llm_func, _priority=8) + + language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) + + summary_length_recommended = global_config["summary_length_recommended"] + + prompt_template = PROMPTS["summarize_entity_descriptions"] + + # Convert descriptions to JSONL format and apply token-based truncation + tokenizer = global_config["tokenizer"] + summary_context_size = global_config["summary_context_size"] + + # Create list of JSON objects with "Description" field + json_descriptions = [{"Description": desc} for desc in description_list] + + # Use truncate_list_by_token_size for length truncation + truncated_json_descriptions = truncate_list_by_token_size( + json_descriptions, + key=lambda x: json.dumps(x, ensure_ascii=False), + max_token_size=summary_context_size, + tokenizer=tokenizer, + ) + + # Convert to JSONL format (one JSON object per line) + joined_descriptions = "\n".join( + json.dumps(desc, ensure_ascii=False) for desc in truncated_json_descriptions + ) + + # Prepare context for the prompt + context_base = dict( + description_type=description_type, + description_name=description_name, + description_list=joined_descriptions, + summary_length=summary_length_recommended, + language=language, + ) + use_prompt = prompt_template.format(**context_base) + + # Use LLM function with cache (higher priority for summary generation) + summary, _ = await use_llm_func_with_cache( + use_prompt, + use_llm_func, + llm_response_cache=llm_response_cache, + cache_type="summary", + ) + + # Check summary token length against embedding limit + embedding_token_limit = global_config.get("embedding_token_limit") + if embedding_token_limit is not None and summary: + tokenizer = global_config["tokenizer"] + summary_token_count = len(tokenizer.encode(summary)) + threshold = int(embedding_token_limit) + + if summary_token_count > threshold: + logger.warning( + f"Summary tokens({summary_token_count}) exceeds embedding_token_limit({embedding_token_limit}) " + f" for {description_type}: {description_name}" + ) + + return summary + + +async def _handle_single_entity_extraction( + record_attributes: list[str], + chunk_key: str, + timestamp: int, + file_path: str = "unknown_source", +): + if len(record_attributes) != 4 or "entity" not in record_attributes[0]: + if len(record_attributes) > 1 and "entity" in record_attributes[0]: + logger.warning( + f"{chunk_key}: LLM output format error; found {len(record_attributes)}/4 feilds on ENTITY `{record_attributes[1]}` @ `{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" + ) + logger.debug(record_attributes) + return None + + try: + entity_name = sanitize_and_normalize_extracted_text( + record_attributes[1], remove_inner_quotes=True + ) + + # Validate entity name after all cleaning steps + if not entity_name or not entity_name.strip(): + logger.info( + f"Empty entity name found after sanitization. Original: '{record_attributes[1]}'" + ) + return None + + # Process entity type with same cleaning pipeline + entity_type = sanitize_and_normalize_extracted_text( + record_attributes[2], remove_inner_quotes=True + ) + + if not entity_type.strip() or any( + char in entity_type for char in ["'", "(", ")", "<", ">", "|", "/", "\\"] + ): + logger.warning( + f"Entity extraction error: invalid entity type in: {record_attributes}" + ) + return None + + # Remove spaces and convert to lowercase + entity_type = entity_type.replace(" ", "").lower() + + # Process entity description with same cleaning pipeline + entity_description = sanitize_and_normalize_extracted_text(record_attributes[3]) + + if not entity_description.strip(): + logger.warning( + f"Entity extraction error: empty description for entity '{entity_name}' of type '{entity_type}'" + ) + return None + + return dict( + entity_name=entity_name, + entity_type=entity_type, + description=entity_description, + source_id=chunk_key, + file_path=file_path, + timestamp=timestamp, + ) + + except ValueError as e: + logger.error( + f"Entity extraction failed due to encoding issues in chunk {chunk_key}: {e}" + ) + return None + except Exception as e: + logger.error( + f"Entity extraction failed with unexpected error in chunk {chunk_key}: {e}" + ) + return None + + +async def _handle_single_relationship_extraction( + record_attributes: list[str], + chunk_key: str, + timestamp: int, + file_path: str = "unknown_source", +): + if ( + len(record_attributes) != 5 or "relation" not in record_attributes[0] + ): # treat "relationship" and "relation" interchangeable + if len(record_attributes) > 1 and "relation" in record_attributes[0]: + logger.warning( + f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" + ) + logger.debug(record_attributes) + return None + + try: + source = sanitize_and_normalize_extracted_text( + record_attributes[1], remove_inner_quotes=True + ) + target = sanitize_and_normalize_extracted_text( + record_attributes[2], remove_inner_quotes=True + ) + + # Validate entity names after all cleaning steps + if not source: + logger.info( + f"Empty source entity found after sanitization. Original: '{record_attributes[1]}'" + ) + return None + + if not target: + logger.info( + f"Empty target entity found after sanitization. Original: '{record_attributes[2]}'" + ) + return None + + if source == target: + logger.debug( + f"Relationship source and target are the same in: {record_attributes}" + ) + return None + + # Process keywords with same cleaning pipeline + edge_keywords = sanitize_and_normalize_extracted_text( + record_attributes[3], remove_inner_quotes=True + ) + edge_keywords = edge_keywords.replace(",", ",") + + # Process relationship description with same cleaning pipeline + edge_description = sanitize_and_normalize_extracted_text(record_attributes[4]) + + edge_source_id = chunk_key + weight = ( + float(record_attributes[-1].strip('"').strip("'")) + if is_float_regex(record_attributes[-1].strip('"').strip("'")) + else 1.0 + ) + + return dict( + src_id=source, + tgt_id=target, + weight=weight, + description=edge_description, + keywords=edge_keywords, + source_id=edge_source_id, + file_path=file_path, + timestamp=timestamp, + ) + + except ValueError as e: + logger.warning( + f"Relationship extraction failed due to encoding issues in chunk {chunk_key}: {e}" + ) + return None + except Exception as e: + logger.warning( + f"Relationship extraction failed with unexpected error in chunk {chunk_key}: {e}" + ) + return None + + +async def rebuild_knowledge_from_chunks( + entities_to_rebuild: dict[str, list[str]], + relationships_to_rebuild: dict[tuple[str, str], list[str]], + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_storage: BaseKVStorage, + llm_response_cache: BaseKVStorage, + global_config: dict[str, str], + pipeline_status: dict | None = None, + pipeline_status_lock=None, + entity_chunks_storage: BaseKVStorage | None = None, + relation_chunks_storage: BaseKVStorage | None = None, +) -> None: + """Rebuild entity and relationship descriptions from cached extraction results with parallel processing + + This method uses cached LLM extraction results instead of calling LLM again, + following the same approach as the insert process. Now with parallel processing + controlled by llm_model_max_async and using get_storage_keyed_lock for data consistency. + + Args: + entities_to_rebuild: Dict mapping entity_name -> list of remaining chunk_ids + relationships_to_rebuild: Dict mapping (src, tgt) -> list of remaining chunk_ids + knowledge_graph_inst: Knowledge graph storage + entities_vdb: Entity vector database + relationships_vdb: Relationship vector database + text_chunks_storage: Text chunks storage + llm_response_cache: LLM response cache + global_config: Global configuration containing llm_model_max_async + pipeline_status: Pipeline status dictionary + pipeline_status_lock: Lock for pipeline status + entity_chunks_storage: KV storage maintaining full chunk IDs per entity + relation_chunks_storage: KV storage maintaining full chunk IDs per relation + """ + if not entities_to_rebuild and not relationships_to_rebuild: + return + + # Get all referenced chunk IDs + all_referenced_chunk_ids = set() + for chunk_ids in entities_to_rebuild.values(): + all_referenced_chunk_ids.update(chunk_ids) + for chunk_ids in relationships_to_rebuild.values(): + all_referenced_chunk_ids.update(chunk_ids) + + status_message = f"Rebuilding knowledge from {len(all_referenced_chunk_ids)} cached chunk extractions (parallel processing)" + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + # Get cached extraction results for these chunks using storage + # cached_results: chunk_id -> [list of (extraction_result, create_time) from LLM cache sorted by create_time of the first extraction_result] + cached_results = await _get_cached_extraction_results( + llm_response_cache, + all_referenced_chunk_ids, + text_chunks_storage=text_chunks_storage, + ) + + if not cached_results: + status_message = "No cached extraction results found, cannot rebuild" + logger.warning(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + return + + # Process cached results to get entities and relationships for each chunk + chunk_entities = {} # chunk_id -> {entity_name: [entity_data]} + chunk_relationships = {} # chunk_id -> {(src, tgt): [relationship_data]} + + for chunk_id, results in cached_results.items(): + try: + # Handle multiple extraction results per chunk + chunk_entities[chunk_id] = defaultdict(list) + chunk_relationships[chunk_id] = defaultdict(list) + + # process multiple LLM extraction results for a single chunk_id + for result in results: + entities, relationships = await _rebuild_from_extraction_result( + text_chunks_storage=text_chunks_storage, + chunk_id=chunk_id, + extraction_result=result[0], + timestamp=result[1], + ) + + # Merge entities and relationships from this extraction result + # Compare description lengths and keep the better version for the same chunk_id + for entity_name, entity_list in entities.items(): + if entity_name not in chunk_entities[chunk_id]: + # New entity for this chunk_id + chunk_entities[chunk_id][entity_name].extend(entity_list) + elif len(chunk_entities[chunk_id][entity_name]) == 0: + # Empty list, add the new entities + chunk_entities[chunk_id][entity_name].extend(entity_list) + else: + # Compare description lengths and keep the better one + existing_desc_len = len( + chunk_entities[chunk_id][entity_name][0].get( + "description", "" + ) + or "" + ) + new_desc_len = len(entity_list[0].get("description", "") or "") + + if new_desc_len > existing_desc_len: + # Replace with the new entity that has longer description + chunk_entities[chunk_id][entity_name] = list(entity_list) + # Otherwise keep existing version + + # Compare description lengths and keep the better version for the same chunk_id + for rel_key, rel_list in relationships.items(): + if rel_key not in chunk_relationships[chunk_id]: + # New relationship for this chunk_id + chunk_relationships[chunk_id][rel_key].extend(rel_list) + elif len(chunk_relationships[chunk_id][rel_key]) == 0: + # Empty list, add the new relationships + chunk_relationships[chunk_id][rel_key].extend(rel_list) + else: + # Compare description lengths and keep the better one + existing_desc_len = len( + chunk_relationships[chunk_id][rel_key][0].get( + "description", "" + ) + or "" + ) + new_desc_len = len(rel_list[0].get("description", "") or "") + + if new_desc_len > existing_desc_len: + # Replace with the new relationship that has longer description + chunk_relationships[chunk_id][rel_key] = list(rel_list) + # Otherwise keep existing version + + except Exception as e: + status_message = ( + f"Failed to parse cached extraction result for chunk {chunk_id}: {e}" + ) + logger.info(status_message) # Per requirement, change to info + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + continue + + # Get max async tasks limit from global_config for semaphore control + graph_max_async = global_config.get("llm_model_max_async", 4) * 2 + semaphore = asyncio.Semaphore(graph_max_async) + + # Counters for tracking progress + rebuilt_entities_count = 0 + rebuilt_relationships_count = 0 + failed_entities_count = 0 + failed_relationships_count = 0 + + async def _locked_rebuild_entity(entity_name, chunk_ids): + nonlocal rebuilt_entities_count, failed_entities_count + async with semaphore: + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" + async with get_storage_keyed_lock( + [entity_name], namespace=namespace, enable_logging=False + ): + try: + await _rebuild_single_entity( + knowledge_graph_inst=knowledge_graph_inst, + entities_vdb=entities_vdb, + entity_name=entity_name, + chunk_ids=chunk_ids, + chunk_entities=chunk_entities, + llm_response_cache=llm_response_cache, + global_config=global_config, + entity_chunks_storage=entity_chunks_storage, + ) + rebuilt_entities_count += 1 + except Exception as e: + failed_entities_count += 1 + status_message = f"Failed to rebuild `{entity_name}`: {e}" + logger.info(status_message) # Per requirement, change to info + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + async def _locked_rebuild_relationship(src, tgt, chunk_ids): + nonlocal rebuilt_relationships_count, failed_relationships_count + async with semaphore: + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" + # Sort src and tgt to ensure order-independent lock key generation + sorted_key_parts = sorted([src, tgt]) + async with get_storage_keyed_lock( + sorted_key_parts, + namespace=namespace, + enable_logging=False, + ): + try: + await _rebuild_single_relationship( + knowledge_graph_inst=knowledge_graph_inst, + relationships_vdb=relationships_vdb, + entities_vdb=entities_vdb, + src=src, + tgt=tgt, + chunk_ids=chunk_ids, + chunk_relationships=chunk_relationships, + llm_response_cache=llm_response_cache, + global_config=global_config, + relation_chunks_storage=relation_chunks_storage, + entity_chunks_storage=entity_chunks_storage, + pipeline_status=pipeline_status, + pipeline_status_lock=pipeline_status_lock, + ) + rebuilt_relationships_count += 1 + except Exception as e: + failed_relationships_count += 1 + status_message = f"Failed to rebuild `{src}`~`{tgt}`: {e}" + logger.info(status_message) # Per requirement, change to info + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + # Create tasks for parallel processing + tasks = [] + + # Add entity rebuilding tasks + for entity_name, chunk_ids in entities_to_rebuild.items(): + task = asyncio.create_task(_locked_rebuild_entity(entity_name, chunk_ids)) + tasks.append(task) + + # Add relationship rebuilding tasks + for (src, tgt), chunk_ids in relationships_to_rebuild.items(): + task = asyncio.create_task(_locked_rebuild_relationship(src, tgt, chunk_ids)) + tasks.append(task) + + # Log parallel processing start + status_message = f"Starting parallel rebuild of {len(entities_to_rebuild)} entities and {len(relationships_to_rebuild)} relationships (async: {graph_max_async})" + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + # Execute all tasks in parallel with semaphore control and early failure detection + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + + # Check if any task raised an exception and ensure all exceptions are retrieved + first_exception = None + + for task in done: + try: + exception = task.exception() + if exception is not None: + if first_exception is None: + first_exception = exception + else: + # Task completed successfully, retrieve result to mark as processed + task.result() + except Exception as e: + if first_exception is None: + first_exception = e + + # If any task failed, cancel all pending tasks and raise the first exception + if first_exception is not None: + # Cancel all pending tasks + for pending_task in pending: + pending_task.cancel() + + # Wait for cancellation to complete + if pending: + await asyncio.wait(pending) + + # Re-raise the first exception to notify the caller + raise first_exception + + # Final status report + status_message = f"KG rebuild completed: {rebuilt_entities_count} entities and {rebuilt_relationships_count} relationships rebuilt successfully." + if failed_entities_count > 0 or failed_relationships_count > 0: + status_message += f" Failed: {failed_entities_count} entities, {failed_relationships_count} relationships." + + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + +async def _get_cached_extraction_results( + llm_response_cache: BaseKVStorage, + chunk_ids: set[str], + text_chunks_storage: BaseKVStorage, +) -> dict[str, list[str]]: + """Get cached extraction results for specific chunk IDs + + This function retrieves cached LLM extraction results for the given chunk IDs and returns + them sorted by creation time. The results are sorted at two levels: + 1. Individual extraction results within each chunk are sorted by create_time (earliest first) + 2. Chunks themselves are sorted by the create_time of their earliest extraction result + + Args: + llm_response_cache: LLM response cache storage + chunk_ids: Set of chunk IDs to get cached results for + text_chunks_storage: Text chunks storage for retrieving chunk data and LLM cache references + + Returns: + Dict mapping chunk_id -> list of extraction_result_text, where: + - Keys (chunk_ids) are ordered by the create_time of their first extraction result + - Values (extraction results) are ordered by create_time within each chunk + """ + cached_results = {} + + # Collect all LLM cache IDs from chunks + all_cache_ids = set() + + # Read from storage + chunk_data_list = await text_chunks_storage.get_by_ids(list(chunk_ids)) + for chunk_data in chunk_data_list: + if chunk_data and isinstance(chunk_data, dict): + llm_cache_list = chunk_data.get("llm_cache_list", []) + if llm_cache_list: + all_cache_ids.update(llm_cache_list) + else: + logger.warning(f"Chunk data is invalid or None: {chunk_data}") + + if not all_cache_ids: + logger.warning(f"No LLM cache IDs found for {len(chunk_ids)} chunk IDs") + return cached_results + + # Batch get LLM cache entries + cache_data_list = await llm_response_cache.get_by_ids(list(all_cache_ids)) + + # Process cache entries and group by chunk_id + valid_entries = 0 + for cache_entry in cache_data_list: + if ( + cache_entry is not None + and isinstance(cache_entry, dict) + and cache_entry.get("cache_type") == "extract" + and cache_entry.get("chunk_id") in chunk_ids + ): + chunk_id = cache_entry["chunk_id"] + extraction_result = cache_entry["return"] + create_time = cache_entry.get( + "create_time", 0 + ) # Get creation time, default to 0 + valid_entries += 1 + + # Support multiple LLM caches per chunk + if chunk_id not in cached_results: + cached_results[chunk_id] = [] + # Store tuple with extraction result and creation time for sorting + cached_results[chunk_id].append((extraction_result, create_time)) + + # Sort extraction results by create_time for each chunk and collect earliest times + chunk_earliest_times = {} + for chunk_id in cached_results: + # Sort by create_time (x[1]), then extract only extraction_result (x[0]) + cached_results[chunk_id].sort(key=lambda x: x[1]) + # Store the earliest create_time for this chunk (first item after sorting) + chunk_earliest_times[chunk_id] = cached_results[chunk_id][0][1] + + # Sort cached_results by the earliest create_time of each chunk + sorted_chunk_ids = sorted( + chunk_earliest_times.keys(), key=lambda chunk_id: chunk_earliest_times[chunk_id] + ) + + # Rebuild cached_results in sorted order + sorted_cached_results = {} + for chunk_id in sorted_chunk_ids: + sorted_cached_results[chunk_id] = cached_results[chunk_id] + + logger.info( + f"Found {valid_entries} valid cache entries, {len(sorted_cached_results)} chunks with results" + ) + return sorted_cached_results # each item: list(extraction_result, create_time) + + +async def _process_extraction_result( + result: str, + chunk_key: str, + timestamp: int, + file_path: str = "unknown_source", + tuple_delimiter: str = "<|#|>", + completion_delimiter: str = "<|COMPLETE|>", +) -> tuple[dict, dict]: + """Process a single extraction result (either initial or gleaning) + Args: + result (str): The extraction result to process + chunk_key (str): The chunk key for source tracking + file_path (str): The file path for citation + tuple_delimiter (str): Delimiter for tuple fields + record_delimiter (str): Delimiter for records + completion_delimiter (str): Delimiter for completion + Returns: + tuple: (nodes_dict, edges_dict) containing the extracted entities and relationships + """ + maybe_nodes = defaultdict(list) + maybe_edges = defaultdict(list) + + if completion_delimiter not in result: + logger.warning( + f"{chunk_key}: Complete delimiter can not be found in extraction result" + ) + + # Split LLL output result to records by "\n" + records = split_string_by_multi_markers( + result, + ["\n", completion_delimiter, completion_delimiter.lower()], + ) + + # Fix LLM output format error which use tuple_delimiter to separate record instead of "\n" + fixed_records = [] + for record in records: + record = record.strip() + if record is None: + continue + entity_records = split_string_by_multi_markers( + record, [f"{tuple_delimiter}entity{tuple_delimiter}"] + ) + for entity_record in entity_records: + if not entity_record.startswith("entity") and not entity_record.startswith( + "relation" + ): + entity_record = f"entity<|{entity_record}" + entity_relation_records = split_string_by_multi_markers( + # treat "relationship" and "relation" interchangeable + entity_record, + [ + f"{tuple_delimiter}relationship{tuple_delimiter}", + f"{tuple_delimiter}relation{tuple_delimiter}", + ], + ) + for entity_relation_record in entity_relation_records: + if not entity_relation_record.startswith( + "entity" + ) and not entity_relation_record.startswith("relation"): + entity_relation_record = ( + f"relation{tuple_delimiter}{entity_relation_record}" + ) + fixed_records = fixed_records + [entity_relation_record] + + if len(fixed_records) != len(records): + logger.warning( + f"{chunk_key}: LLM output format error; find LLM use {tuple_delimiter} as record separators instead new-line" + ) + + for record in fixed_records: + record = record.strip() + if record is None: + continue + + # Fix various forms of tuple_delimiter corruption from the LLM output using the dedicated function + delimiter_core = tuple_delimiter[2:-2] # Extract "#" from "<|#|>" + record = fix_tuple_delimiter_corruption(record, delimiter_core, tuple_delimiter) + if delimiter_core != delimiter_core.lower(): + # change delimiter_core to lower case, and fix again + delimiter_core = delimiter_core.lower() + record = fix_tuple_delimiter_corruption( + record, delimiter_core, tuple_delimiter + ) + + record_attributes = split_string_by_multi_markers(record, [tuple_delimiter]) + + # Try to parse as entity + entity_data = await _handle_single_entity_extraction( + record_attributes, chunk_key, timestamp, file_path + ) + if entity_data is not None: + truncated_name = _truncate_entity_identifier( + entity_data["entity_name"], + DEFAULT_ENTITY_NAME_MAX_LENGTH, + chunk_key, + "Entity name", + ) + entity_data["entity_name"] = truncated_name + maybe_nodes[truncated_name].append(entity_data) + continue + + # Try to parse as relationship + relationship_data = await _handle_single_relationship_extraction( + record_attributes, chunk_key, timestamp, file_path + ) + if relationship_data is not None: + truncated_source = _truncate_entity_identifier( + relationship_data["src_id"], + DEFAULT_ENTITY_NAME_MAX_LENGTH, + chunk_key, + "Relation entity", + ) + truncated_target = _truncate_entity_identifier( + relationship_data["tgt_id"], + DEFAULT_ENTITY_NAME_MAX_LENGTH, + chunk_key, + "Relation entity", + ) + relationship_data["src_id"] = truncated_source + relationship_data["tgt_id"] = truncated_target + maybe_edges[(truncated_source, truncated_target)].append(relationship_data) + + return dict(maybe_nodes), dict(maybe_edges) + + +async def _rebuild_from_extraction_result( + text_chunks_storage: BaseKVStorage, + extraction_result: str, + chunk_id: str, + timestamp: int, +) -> tuple[dict, dict]: + """Parse cached extraction result using the same logic as extract_entities + + Args: + text_chunks_storage: Text chunks storage to get chunk data + extraction_result: The cached LLM extraction result + chunk_id: The chunk ID for source tracking + + Returns: + Tuple of (entities_dict, relationships_dict) + """ + + # Get chunk data for file_path from storage + chunk_data = await text_chunks_storage.get_by_id(chunk_id) + file_path = ( + chunk_data.get("file_path", "unknown_source") + if chunk_data + else "unknown_source" + ) + + # Call the shared processing function + return await _process_extraction_result( + extraction_result, + chunk_id, + timestamp, + file_path, + tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], + completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], + ) + + +async def _rebuild_single_entity( + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + entity_name: str, + chunk_ids: list[str], + chunk_entities: dict, + llm_response_cache: BaseKVStorage, + global_config: dict[str, str], + entity_chunks_storage: BaseKVStorage | None = None, + pipeline_status: dict | None = None, + pipeline_status_lock=None, +) -> None: + """Rebuild a single entity from cached extraction results""" + + # Get current entity data + current_entity = await knowledge_graph_inst.get_node(entity_name) + if not current_entity: + return + + # Helper function to update entity in both graph and vector storage + async def _update_entity_storage( + final_description: str, + entity_type: str, + file_paths: list[str], + source_chunk_ids: list[str], + truncation_info: str = "", + ): + try: + # Update entity in graph storage (critical path) + updated_entity_data = { + **current_entity, + "description": final_description, + "entity_type": entity_type, + "source_id": GRAPH_FIELD_SEP.join(source_chunk_ids), + "file_path": GRAPH_FIELD_SEP.join(file_paths) + if file_paths + else current_entity.get("file_path", "unknown_source"), + "created_at": int(time.time()), + "truncate": truncation_info, + } + await knowledge_graph_inst.upsert_node(entity_name, updated_entity_data) + + # Update entity in vector database (equally critical) + entity_vdb_id = compute_mdhash_id(entity_name, prefix="ent-") + entity_content = f"{entity_name}\n{final_description}" + + vdb_data = { + entity_vdb_id: { + "content": entity_content, + "entity_name": entity_name, + "source_id": updated_entity_data["source_id"], + "description": final_description, + "entity_type": entity_type, + "file_path": updated_entity_data["file_path"], + } + } + + # Use safe operation wrapper - VDB failure must throw exception + await safe_vdb_operation_with_exception( + operation=lambda: entities_vdb.upsert(vdb_data), + operation_name="rebuild_entity_upsert", + entity_name=entity_name, + max_retries=3, + retry_delay=0.1, + ) + + except Exception as e: + error_msg = f"Failed to update entity storage for `{entity_name}`: {e}" + logger.error(error_msg) + raise # Re-raise exception + + # normalized_chunk_ids = merge_source_ids([], chunk_ids) + normalized_chunk_ids = chunk_ids + + if entity_chunks_storage is not None and normalized_chunk_ids: + await entity_chunks_storage.upsert( + { + entity_name: { + "chunk_ids": normalized_chunk_ids, + "count": len(normalized_chunk_ids), + } + } + ) + + limit_method = ( + global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP + ) + + limited_chunk_ids = apply_source_ids_limit( + normalized_chunk_ids, + global_config["max_source_ids_per_entity"], + limit_method, + identifier=f"`{entity_name}`", + ) + + # Collect all entity data from relevant (limited) chunks + all_entity_data = [] + for chunk_id in limited_chunk_ids: + if chunk_id in chunk_entities and entity_name in chunk_entities[chunk_id]: + all_entity_data.extend(chunk_entities[chunk_id][entity_name]) + + if not all_entity_data: + logger.warning( + f"No entity data found for `{entity_name}`, trying to rebuild from relationships" + ) + + # Get all edges connected to this entity + edges = await knowledge_graph_inst.get_node_edges(entity_name) + if not edges: + logger.warning(f"No relations attached to entity `{entity_name}`") + return + + # Collect relationship data to extract entity information + relationship_descriptions = [] + file_paths = set() + + # Get edge data for all connected relationships + for src_id, tgt_id in edges: + edge_data = await knowledge_graph_inst.get_edge(src_id, tgt_id) + if edge_data: + if edge_data.get("description"): + relationship_descriptions.append(edge_data["description"]) + + if edge_data.get("file_path"): + edge_file_paths = edge_data["file_path"].split(GRAPH_FIELD_SEP) + file_paths.update(edge_file_paths) + + # deduplicate descriptions + description_list = list(dict.fromkeys(relationship_descriptions)) + + # Generate final description from relationships or fallback to current + if description_list: + final_description, _ = await _handle_entity_relation_summary( + "Entity", + entity_name, + description_list, + GRAPH_FIELD_SEP, + global_config, + llm_response_cache=llm_response_cache, + ) + else: + final_description = current_entity.get("description", "") + + entity_type = current_entity.get("entity_type", "UNKNOWN") + await _update_entity_storage( + final_description, + entity_type, + file_paths, + limited_chunk_ids, + ) + return + + # Process cached entity data + descriptions = [] + entity_types = [] + file_paths_list = [] + seen_paths = set() + + for entity_data in all_entity_data: + if entity_data.get("description"): + descriptions.append(entity_data["description"]) + if entity_data.get("entity_type"): + entity_types.append(entity_data["entity_type"]) + if entity_data.get("file_path"): + file_path = entity_data["file_path"] + if file_path and file_path not in seen_paths: + file_paths_list.append(file_path) + seen_paths.add(file_path) + + # Apply MAX_FILE_PATHS limit + max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + limit_method = global_config.get("source_ids_limit_method") + + original_count = len(file_paths_list) + if original_count > max_file_paths: + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + # FIFO: keep tail (newest), discard head + file_paths_list = file_paths_list[-max_file_paths:] + else: + # KEEP: keep head (earliest), discard tail + file_paths_list = file_paths_list[:max_file_paths] + + file_paths_list.append( + f"...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})" + ) + logger.info( + f"Limited `{entity_name}`: file_path {original_count} -> {max_file_paths} ({limit_method})" + ) + + # Remove duplicates while preserving order + description_list = list(dict.fromkeys(descriptions)) + entity_types = list(dict.fromkeys(entity_types)) + + # Get most common entity type + entity_type = ( + max(set(entity_types), key=entity_types.count) + if entity_types + else current_entity.get("entity_type", "UNKNOWN") + ) + + # Generate final description from entities or fallback to current + if description_list: + final_description, _ = await _handle_entity_relation_summary( + "Entity", + entity_name, + description_list, + GRAPH_FIELD_SEP, + global_config, + llm_response_cache=llm_response_cache, + ) + else: + final_description = current_entity.get("description", "") + + if len(limited_chunk_ids) < len(normalized_chunk_ids): + truncation_info = ( + f"{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}" + ) + else: + truncation_info = "" + + await _update_entity_storage( + final_description, + entity_type, + file_paths_list, + limited_chunk_ids, + truncation_info, + ) + + # Log rebuild completion with truncation info + status_message = f"Rebuild `{entity_name}` from {len(chunk_ids)} chunks" + if truncation_info: + status_message += f" ({truncation_info})" + logger.info(status_message) + # Update pipeline status + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + +async def _rebuild_single_relationship( + knowledge_graph_inst: BaseGraphStorage, + relationships_vdb: BaseVectorStorage, + entities_vdb: BaseVectorStorage, + src: str, + tgt: str, + chunk_ids: list[str], + chunk_relationships: dict, + llm_response_cache: BaseKVStorage, + global_config: dict[str, str], + relation_chunks_storage: BaseKVStorage | None = None, + entity_chunks_storage: BaseKVStorage | None = None, + pipeline_status: dict | None = None, + pipeline_status_lock=None, +) -> None: + """Rebuild a single relationship from cached extraction results + + Note: This function assumes the caller has already acquired the appropriate + keyed lock for the relationship pair to ensure thread safety. + """ + + # Get current relationship data + current_relationship = await knowledge_graph_inst.get_edge(src, tgt) + if not current_relationship: + return + + # normalized_chunk_ids = merge_source_ids([], chunk_ids) + normalized_chunk_ids = chunk_ids + + if relation_chunks_storage is not None and normalized_chunk_ids: + storage_key = make_relation_chunk_key(src, tgt) + await relation_chunks_storage.upsert( + { + storage_key: { + "chunk_ids": normalized_chunk_ids, + "count": len(normalized_chunk_ids), + } + } + ) + + limit_method = ( + global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP + ) + limited_chunk_ids = apply_source_ids_limit( + normalized_chunk_ids, + global_config["max_source_ids_per_relation"], + limit_method, + identifier=f"`{src}`~`{tgt}`", + ) + + # Collect all relationship data from relevant chunks + all_relationship_data = [] + for chunk_id in limited_chunk_ids: + if chunk_id in chunk_relationships: + # Check both (src, tgt) and (tgt, src) since relationships can be bidirectional + for edge_key in [(src, tgt), (tgt, src)]: + if edge_key in chunk_relationships[chunk_id]: + all_relationship_data.extend( + chunk_relationships[chunk_id][edge_key] + ) + + if not all_relationship_data: + logger.warning(f"No relation data found for `{src}-{tgt}`") + return + + # Merge descriptions and keywords + descriptions = [] + keywords = [] + weights = [] + file_paths_list = [] + seen_paths = set() + + for rel_data in all_relationship_data: + if rel_data.get("description"): + descriptions.append(rel_data["description"]) + if rel_data.get("keywords"): + keywords.append(rel_data["keywords"]) + if rel_data.get("weight"): + weights.append(rel_data["weight"]) + if rel_data.get("file_path"): + file_path = rel_data["file_path"] + if file_path and file_path not in seen_paths: + file_paths_list.append(file_path) + seen_paths.add(file_path) + + # Apply count limit + max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + limit_method = global_config.get("source_ids_limit_method") + + original_count = len(file_paths_list) + if original_count > max_file_paths: + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + # FIFO: keep tail (newest), discard head + file_paths_list = file_paths_list[-max_file_paths:] + else: + # KEEP: keep head (earliest), discard tail + file_paths_list = file_paths_list[:max_file_paths] + + file_paths_list.append( + f"...{file_path_placeholder}...({limit_method} {max_file_paths}/{original_count})" + ) + logger.info( + f"Limited `{src}`~`{tgt}`: file_path {original_count} -> {max_file_paths} ({limit_method})" + ) + + # Remove duplicates while preserving order + description_list = list(dict.fromkeys(descriptions)) + keywords = list(dict.fromkeys(keywords)) + + combined_keywords = ( + ", ".join(set(keywords)) + if keywords + else current_relationship.get("keywords", "") + ) + + weight = sum(weights) if weights else current_relationship.get("weight", 1.0) + + # Generate final description from relations or fallback to current + if description_list: + final_description, _ = await _handle_entity_relation_summary( + "Relation", + f"{src}-{tgt}", + description_list, + GRAPH_FIELD_SEP, + global_config, + llm_response_cache=llm_response_cache, + ) + else: + # fallback to keep current(unchanged) + final_description = current_relationship.get("description", "") + + if len(limited_chunk_ids) < len(normalized_chunk_ids): + truncation_info = ( + f"{limit_method} {len(limited_chunk_ids)}/{len(normalized_chunk_ids)}" + ) + else: + truncation_info = "" + + # Update relationship in graph storage + updated_relationship_data = { + **current_relationship, + "description": final_description + if final_description + else current_relationship.get("description", ""), + "keywords": combined_keywords, + "weight": weight, + "source_id": GRAPH_FIELD_SEP.join(limited_chunk_ids), + "file_path": GRAPH_FIELD_SEP.join([fp for fp in file_paths_list if fp]) + if file_paths_list + else current_relationship.get("file_path", "unknown_source"), + "truncate": truncation_info, + } + + # Ensure both endpoint nodes exist before writing the edge back + # (certain storage backends require pre-existing nodes). + node_description = ( + updated_relationship_data["description"] + if updated_relationship_data.get("description") + else current_relationship.get("description", "") + ) + node_source_id = updated_relationship_data.get("source_id", "") + node_file_path = updated_relationship_data.get("file_path", "unknown_source") + + for node_id in {src, tgt}: + if not (await knowledge_graph_inst.has_node(node_id)): + node_created_at = int(time.time()) + node_data = { + "entity_id": node_id, + "source_id": node_source_id, + "description": node_description, + "entity_type": "UNKNOWN", + "file_path": node_file_path, + "created_at": node_created_at, + "truncate": "", + } + await knowledge_graph_inst.upsert_node(node_id, node_data=node_data) + + # Update entity_chunks_storage for the newly created entity + if entity_chunks_storage is not None and limited_chunk_ids: + await entity_chunks_storage.upsert( + { + node_id: { + "chunk_ids": limited_chunk_ids, + "count": len(limited_chunk_ids), + } + } + ) + + # Update entity_vdb for the newly created entity + if entities_vdb is not None: + entity_vdb_id = compute_mdhash_id(node_id, prefix="ent-") + entity_content = f"{node_id}\n{node_description}" + vdb_data = { + entity_vdb_id: { + "content": entity_content, + "entity_name": node_id, + "source_id": node_source_id, + "entity_type": "UNKNOWN", + "file_path": node_file_path, + } + } + await safe_vdb_operation_with_exception( + operation=lambda payload=vdb_data: entities_vdb.upsert(payload), + operation_name="rebuild_added_entity_upsert", + entity_name=node_id, + max_retries=3, + retry_delay=0.1, + ) + + await knowledge_graph_inst.upsert_edge(src, tgt, updated_relationship_data) + + # Update relationship in vector database + # Sort src and tgt to ensure consistent ordering (smaller string first) + if src > tgt: + src, tgt = tgt, src + try: + rel_vdb_id = compute_mdhash_id(src + tgt, prefix="rel-") + rel_vdb_id_reverse = compute_mdhash_id(tgt + src, prefix="rel-") + + # Delete old vector records first (both directions to be safe) + try: + await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse]) + except Exception as e: + logger.debug( + f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}" + ) + + # Insert new vector record + rel_content = f"{combined_keywords}\t{src}\n{tgt}\n{final_description}" + vdb_data = { + rel_vdb_id: { + "src_id": src, + "tgt_id": tgt, + "source_id": updated_relationship_data["source_id"], + "content": rel_content, + "keywords": combined_keywords, + "description": final_description, + "weight": weight, + "file_path": updated_relationship_data["file_path"], + } + } + + # Use safe operation wrapper - VDB failure must throw exception + await safe_vdb_operation_with_exception( + operation=lambda: relationships_vdb.upsert(vdb_data), + operation_name="rebuild_relationship_upsert", + entity_name=f"{src}-{tgt}", + max_retries=3, + retry_delay=0.2, + ) + + except Exception as e: + error_msg = f"Failed to rebuild relationship storage for `{src}-{tgt}`: {e}" + logger.error(error_msg) + raise # Re-raise exception + + # Log rebuild completion with truncation info + status_message = f"Rebuild `{src}`~`{tgt}` from {len(chunk_ids)} chunks" + if truncation_info: + status_message += f" ({truncation_info})" + # Add truncation info from apply_source_ids_limit if truncation occurred + if len(limited_chunk_ids) < len(normalized_chunk_ids): + truncation_info = ( + f" ({limit_method}:{len(limited_chunk_ids)}/{len(normalized_chunk_ids)})" + ) + status_message += truncation_info + + logger.info(status_message) + + # Update pipeline status + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + +async def _merge_nodes_then_upsert( + entity_name: str, + nodes_data: list[dict], + knowledge_graph_inst: BaseGraphStorage, + entity_vdb: BaseVectorStorage | None, + global_config: dict, + pipeline_status: dict = None, + pipeline_status_lock=None, + llm_response_cache: BaseKVStorage | None = None, + entity_chunks_storage: BaseKVStorage | None = None, +): + """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert.""" + already_entity_types = [] + already_source_ids = [] + already_description = [] + already_file_paths = [] + + # 1. Get existing node data from knowledge graph + already_node = await knowledge_graph_inst.get_node(entity_name) + if already_node: + already_entity_types.append(already_node["entity_type"]) + already_source_ids.extend(already_node["source_id"].split(GRAPH_FIELD_SEP)) + already_file_paths.extend(already_node["file_path"].split(GRAPH_FIELD_SEP)) + already_description.extend(already_node["description"].split(GRAPH_FIELD_SEP)) + + new_source_ids = [dp["source_id"] for dp in nodes_data if dp.get("source_id")] + + existing_full_source_ids = [] + if entity_chunks_storage is not None: + stored_chunks = await entity_chunks_storage.get_by_id(entity_name) + if stored_chunks and isinstance(stored_chunks, dict): + existing_full_source_ids = [ + chunk_id for chunk_id in stored_chunks.get("chunk_ids", []) if chunk_id + ] + + if not existing_full_source_ids: + existing_full_source_ids = [ + chunk_id for chunk_id in already_source_ids if chunk_id + ] + + # 2. Merging new source ids with existing ones + full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids) + + if entity_chunks_storage is not None and full_source_ids: + await entity_chunks_storage.upsert( + { + entity_name: { + "chunk_ids": full_source_ids, + "count": len(full_source_ids), + } + } + ) + + # 3. Finalize source_id by applying source ids limit + limit_method = global_config.get("source_ids_limit_method") + max_source_limit = global_config.get("max_source_ids_per_entity") + source_ids = apply_source_ids_limit( + full_source_ids, + max_source_limit, + limit_method, + identifier=f"`{entity_name}`", + ) + + # 4. Only keep nodes not filter by apply_source_ids_limit if limit_method is KEEP + if limit_method == SOURCE_IDS_LIMIT_METHOD_KEEP: + allowed_source_ids = set(source_ids) + filtered_nodes = [] + for dp in nodes_data: + source_id = dp.get("source_id") + # Skip descriptions sourced from chunks dropped by the limitation cap + if ( + source_id + and source_id not in allowed_source_ids + and source_id not in existing_full_source_ids + ): + continue + filtered_nodes.append(dp) + nodes_data = filtered_nodes + else: # In FIFO mode, keep all nodes - truncation happens at source_ids level only + nodes_data = list(nodes_data) + + # 5. Check if we need to skip summary due to source_ids limit + if ( + limit_method == SOURCE_IDS_LIMIT_METHOD_KEEP + and len(existing_full_source_ids) >= max_source_limit + and not nodes_data + ): + if already_node: + logger.info( + f"Skipped `{entity_name}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}" + ) + existing_node_data = dict(already_node) + return existing_node_data + else: + logger.error(f"Internal Error: already_node missing for `{entity_name}`") + raise ValueError( + f"Internal Error: already_node missing for `{entity_name}`" + ) + + # 6.1 Finalize source_id + source_id = GRAPH_FIELD_SEP.join(source_ids) + + # 6.2 Finalize entity type by highest count + entity_type = sorted( + Counter( + [dp["entity_type"] for dp in nodes_data] + already_entity_types + ).items(), + key=lambda x: x[1], + reverse=True, + )[0][0] + + # 7. Deduplicate nodes by description, keeping first occurrence in the same document + unique_nodes = {} + for dp in nodes_data: + desc = dp.get("description") + if not desc: + continue + if desc not in unique_nodes: + unique_nodes[desc] = dp + + # Sort description by timestamp, then by description length when timestamps are the same + sorted_nodes = sorted( + unique_nodes.values(), + key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))), + ) + sorted_descriptions = [dp["description"] for dp in sorted_nodes] + + # Combine already_description with sorted new sorted descriptions + description_list = already_description + sorted_descriptions + if not description_list: + logger.error(f"Entity {entity_name} has no description") + raise ValueError(f"Entity {entity_name} has no description") + + # Check for cancellation before LLM summary + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled during entity summary") + + # 8. Get summary description an LLM usage status + description, llm_was_used = await _handle_entity_relation_summary( + "Entity", + entity_name, + description_list, + GRAPH_FIELD_SEP, + global_config, + llm_response_cache, + ) + + # 9. Build file_path within MAX_FILE_PATHS + file_paths_list = [] + seen_paths = set() + has_placeholder = False # Indicating file_path has been truncated before + + max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + + # Collect from already_file_paths, excluding placeholder + for fp in already_file_paths: + if fp and fp.startswith(f"...{file_path_placeholder}"): # Skip placeholders + has_placeholder = True + continue + if fp and fp not in seen_paths: + file_paths_list.append(fp) + seen_paths.add(fp) + + # Collect from new data + for dp in nodes_data: + file_path_item = dp.get("file_path") + if file_path_item and file_path_item not in seen_paths: + file_paths_list.append(file_path_item) + seen_paths.add(file_path_item) + + # Apply count limit + if len(file_paths_list) > max_file_paths: + limit_method = global_config.get( + "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP + ) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + # Add + sign to indicate actual file count is higher + original_count_str = ( + f"{len(file_paths_list)}+" if has_placeholder else str(len(file_paths_list)) + ) + + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + # FIFO: keep tail (newest), discard head + file_paths_list = file_paths_list[-max_file_paths:] + file_paths_list.append(f"...{file_path_placeholder}...(FIFO)") + else: + # KEEP: keep head (earliest), discard tail + file_paths_list = file_paths_list[:max_file_paths] + file_paths_list.append(f"...{file_path_placeholder}...(KEEP Old)") + + logger.info( + f"Limited `{entity_name}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})" + ) + # Finalize file_path + file_path = GRAPH_FIELD_SEP.join(file_paths_list) + + # 10.Log based on actual LLM usage + num_fragment = len(description_list) + already_fragment = len(already_description) + if llm_was_used: + status_message = f"LLMmrg: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}" + else: + status_message = f"Merged: `{entity_name}` | {already_fragment}+{num_fragment - already_fragment}" + + truncation_info = truncation_info_log = "" + if len(source_ids) < len(full_source_ids): + # Add truncation info from apply_source_ids_limit if truncation occurred + truncation_info_log = f"{limit_method} {len(source_ids)}/{len(full_source_ids)}" + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + truncation_info = truncation_info_log + else: + truncation_info = "KEEP Old" + + deduplicated_num = already_fragment + len(nodes_data) - num_fragment + dd_message = "" + if deduplicated_num > 0: + # Duplicated description detected across multiple trucks for the same entity + dd_message = f"dd {deduplicated_num}" + + if dd_message or truncation_info_log: + status_message += ( + f" ({', '.join(filter(None, [truncation_info_log, dd_message]))})" + ) + + # Add message to pipeline satus when merge happens + if already_fragment > 0 or llm_was_used: + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + else: + logger.debug(status_message) + + # 11. Update both graph and vector db + node_data = dict( + entity_id=entity_name, + entity_type=entity_type, + description=description, + source_id=source_id, + file_path=file_path, + created_at=int(time.time()), + truncate=truncation_info, + ) + await knowledge_graph_inst.upsert_node( + entity_name, + node_data=node_data, + ) + node_data["entity_name"] = entity_name + if entity_vdb is not None: + entity_vdb_id = compute_mdhash_id(str(entity_name), prefix="ent-") + entity_content = f"{entity_name}\n{description}" + data_for_vdb = { + entity_vdb_id: { + "entity_name": entity_name, + "entity_type": entity_type, + "content": entity_content, + "source_id": source_id, + "file_path": file_path, + } + } + await safe_vdb_operation_with_exception( + operation=lambda payload=data_for_vdb: entity_vdb.upsert(payload), + operation_name="entity_upsert", + entity_name=entity_name, + max_retries=3, + retry_delay=0.1, + ) + return node_data + + +async def _merge_edges_then_upsert( + src_id: str, + tgt_id: str, + edges_data: list[dict], + knowledge_graph_inst: BaseGraphStorage, + relationships_vdb: BaseVectorStorage | None, + entity_vdb: BaseVectorStorage | None, + global_config: dict, + pipeline_status: dict = None, + pipeline_status_lock=None, + llm_response_cache: BaseKVStorage | None = None, + added_entities: list = None, # New parameter to track entities added during edge processing + relation_chunks_storage: BaseKVStorage | None = None, + entity_chunks_storage: BaseKVStorage | None = None, +): + if src_id == tgt_id: + return None + + already_edge = None + already_weights = [] + already_source_ids = [] + already_description = [] + already_keywords = [] + already_file_paths = [] + + # 1. Get existing edge data from graph storage + if await knowledge_graph_inst.has_edge(src_id, tgt_id): + already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id) + # Handle the case where get_edge returns None or missing fields + if already_edge: + # Get weight with default 1.0 if missing + already_weights.append(already_edge.get("weight", 1.0)) + + # Get source_id with empty string default if missing or None + if already_edge.get("source_id") is not None: + already_source_ids.extend( + already_edge["source_id"].split(GRAPH_FIELD_SEP) + ) + + # Get file_path with empty string default if missing or None + if already_edge.get("file_path") is not None: + already_file_paths.extend( + already_edge["file_path"].split(GRAPH_FIELD_SEP) + ) + + # Get description with empty string default if missing or None + if already_edge.get("description") is not None: + already_description.extend( + already_edge["description"].split(GRAPH_FIELD_SEP) + ) + + # Get keywords with empty string default if missing or None + if already_edge.get("keywords") is not None: + already_keywords.extend( + split_string_by_multi_markers( + already_edge["keywords"], [GRAPH_FIELD_SEP] + ) + ) + + new_source_ids = [dp["source_id"] for dp in edges_data if dp.get("source_id")] + + storage_key = make_relation_chunk_key(src_id, tgt_id) + existing_full_source_ids = [] + if relation_chunks_storage is not None: + stored_chunks = await relation_chunks_storage.get_by_id(storage_key) + if stored_chunks and isinstance(stored_chunks, dict): + existing_full_source_ids = [ + chunk_id for chunk_id in stored_chunks.get("chunk_ids", []) if chunk_id + ] + + if not existing_full_source_ids: + existing_full_source_ids = [ + chunk_id for chunk_id in already_source_ids if chunk_id + ] + + # 2. Merge new source ids with existing ones + full_source_ids = merge_source_ids(existing_full_source_ids, new_source_ids) + + if relation_chunks_storage is not None and full_source_ids: + await relation_chunks_storage.upsert( + { + storage_key: { + "chunk_ids": full_source_ids, + "count": len(full_source_ids), + } + } + ) + + # 3. Finalize source_id by applying source ids limit + limit_method = global_config.get("source_ids_limit_method") + max_source_limit = global_config.get("max_source_ids_per_relation") + source_ids = apply_source_ids_limit( + full_source_ids, + max_source_limit, + limit_method, + identifier=f"`{src_id}`~`{tgt_id}`", + ) + limit_method = ( + global_config.get("source_ids_limit_method") or SOURCE_IDS_LIMIT_METHOD_KEEP + ) + + # 4. Only keep edges with source_id in the final source_ids list if in KEEP mode + if limit_method == SOURCE_IDS_LIMIT_METHOD_KEEP: + allowed_source_ids = set(source_ids) + filtered_edges = [] + for dp in edges_data: + source_id = dp.get("source_id") + # Skip relationship fragments sourced from chunks dropped by keep oldest cap + if ( + source_id + and source_id not in allowed_source_ids + and source_id not in existing_full_source_ids + ): + continue + filtered_edges.append(dp) + edges_data = filtered_edges + else: # In FIFO mode, keep all edges - truncation happens at source_ids level only + edges_data = list(edges_data) + + # 5. Check if we need to skip summary due to source_ids limit + if ( + limit_method == SOURCE_IDS_LIMIT_METHOD_KEEP + and len(existing_full_source_ids) >= max_source_limit + and not edges_data + ): + if already_edge: + logger.info( + f"Skipped `{src_id}`~`{tgt_id}`: KEEP old chunks {already_source_ids}/{len(full_source_ids)}" + ) + existing_edge_data = dict(already_edge) + return existing_edge_data + else: + logger.error( + f"Internal Error: already_node missing for `{src_id}`~`{tgt_id}`" + ) + raise ValueError( + f"Internal Error: already_node missing for `{src_id}`~`{tgt_id}`" + ) + + # 6.1 Finalize source_id + source_id = GRAPH_FIELD_SEP.join(source_ids) + + # 6.2 Finalize weight by summing new edges and existing weights + weight = sum([dp["weight"] for dp in edges_data] + already_weights) + + # 6.2 Finalize keywords by merging existing and new keywords + all_keywords = set() + # Process already_keywords (which are comma-separated) + for keyword_str in already_keywords: + if keyword_str: # Skip empty strings + all_keywords.update(k.strip() for k in keyword_str.split(",") if k.strip()) + # Process new keywords from edges_data + for edge in edges_data: + if edge.get("keywords"): + all_keywords.update( + k.strip() for k in edge["keywords"].split(",") if k.strip() + ) + # Join all unique keywords with commas + keywords = ",".join(sorted(all_keywords)) + + # 7. Deduplicate by description, keeping first occurrence in the same document + unique_edges = {} + for dp in edges_data: + description_value = dp.get("description") + if not description_value: + continue + if description_value not in unique_edges: + unique_edges[description_value] = dp + + # Sort description by timestamp, then by description length (largest to smallest) when timestamps are the same + sorted_edges = sorted( + unique_edges.values(), + key=lambda x: (x.get("timestamp", 0), -len(x.get("description", ""))), + ) + sorted_descriptions = [dp["description"] for dp in sorted_edges] + + # Combine already_description with sorted new descriptions + description_list = already_description + sorted_descriptions + if not description_list: + logger.error(f"Relation {src_id}~{tgt_id} has no description") + raise ValueError(f"Relation {src_id}~{tgt_id} has no description") + + # Check for cancellation before LLM summary + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during relation summary" + ) + + # 8. Get summary description an LLM usage status + description, llm_was_used = await _handle_entity_relation_summary( + "Relation", + f"({src_id}, {tgt_id})", + description_list, + GRAPH_FIELD_SEP, + global_config, + llm_response_cache, + ) + + # 9. Build file_path within MAX_FILE_PATHS limit + file_paths_list = [] + seen_paths = set() + has_placeholder = False # Track if already_file_paths contains placeholder + + max_file_paths = global_config.get("max_file_paths", DEFAULT_MAX_FILE_PATHS) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + + # Collect from already_file_paths, excluding placeholder + for fp in already_file_paths: + # Check if this is a placeholder record + if fp and fp.startswith(f"...{file_path_placeholder}"): # Skip placeholders + has_placeholder = True + continue + if fp and fp not in seen_paths: + file_paths_list.append(fp) + seen_paths.add(fp) + + # Collect from new data + for dp in edges_data: + file_path_item = dp.get("file_path") + if file_path_item and file_path_item not in seen_paths: + file_paths_list.append(file_path_item) + seen_paths.add(file_path_item) + + # Apply count limit + if len(file_paths_list) > max_file_paths: + limit_method = global_config.get( + "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP + ) + file_path_placeholder = global_config.get( + "file_path_more_placeholder", DEFAULT_FILE_PATH_MORE_PLACEHOLDER + ) + + # Add + sign to indicate actual file count is higher + original_count_str = ( + f"{len(file_paths_list)}+" if has_placeholder else str(len(file_paths_list)) + ) + + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + # FIFO: keep tail (newest), discard head + file_paths_list = file_paths_list[-max_file_paths:] + file_paths_list.append(f"...{file_path_placeholder}...(FIFO)") + else: + # KEEP: keep head (earliest), discard tail + file_paths_list = file_paths_list[:max_file_paths] + file_paths_list.append(f"...{file_path_placeholder}...(KEEP Old)") + + logger.info( + f"Limited `{src_id}`~`{tgt_id}`: file_path {original_count_str} -> {max_file_paths} ({limit_method})" + ) + # Finalize file_path + file_path = GRAPH_FIELD_SEP.join(file_paths_list) + + # 10. Log based on actual LLM usage + num_fragment = len(description_list) + already_fragment = len(already_description) + if llm_was_used: + status_message = f"LLMmrg: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}" + else: + status_message = f"Merged: `{src_id}`~`{tgt_id}` | {already_fragment}+{num_fragment - already_fragment}" + + truncation_info = truncation_info_log = "" + if len(source_ids) < len(full_source_ids): + # Add truncation info from apply_source_ids_limit if truncation occurred + truncation_info_log = f"{limit_method} {len(source_ids)}/{len(full_source_ids)}" + if limit_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + truncation_info = truncation_info_log + else: + truncation_info = "KEEP Old" + + deduplicated_num = already_fragment + len(edges_data) - num_fragment + dd_message = "" + if deduplicated_num > 0: + # Duplicated description detected across multiple trucks for the same entity + dd_message = f"dd {deduplicated_num}" + + if dd_message or truncation_info_log: + status_message += ( + f" ({', '.join(filter(None, [truncation_info_log, dd_message]))})" + ) + + # Add message to pipeline satus when merge happens + if already_fragment > 0 or llm_was_used: + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + else: + logger.debug(status_message) + + # 11. Update both graph and vector db + for need_insert_id in [src_id, tgt_id]: + # Optimization: Use get_node instead of has_node + get_node + existing_node = await knowledge_graph_inst.get_node(need_insert_id) + + if existing_node is None: + # Node doesn't exist - create new node + node_created_at = int(time.time()) + node_data = { + "entity_id": need_insert_id, + "source_id": source_id, + "description": description, + "entity_type": "UNKNOWN", + "file_path": file_path, + "created_at": node_created_at, + "truncate": "", + } + await knowledge_graph_inst.upsert_node(need_insert_id, node_data=node_data) + + # Update entity_chunks_storage for the newly created entity + if entity_chunks_storage is not None: + chunk_ids = [chunk_id for chunk_id in full_source_ids if chunk_id] + if chunk_ids: + await entity_chunks_storage.upsert( + { + need_insert_id: { + "chunk_ids": chunk_ids, + "count": len(chunk_ids), + } + } + ) + + if entity_vdb is not None: + entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-") + entity_content = f"{need_insert_id}\n{description}" + vdb_data = { + entity_vdb_id: { + "content": entity_content, + "entity_name": need_insert_id, + "source_id": source_id, + "entity_type": "UNKNOWN", + "file_path": file_path, + } + } + await safe_vdb_operation_with_exception( + operation=lambda payload=vdb_data: entity_vdb.upsert(payload), + operation_name="added_entity_upsert", + entity_name=need_insert_id, + max_retries=3, + retry_delay=0.1, + ) + + # Track entities added during edge processing + if added_entities is not None: + entity_data = { + "entity_name": need_insert_id, + "entity_type": "UNKNOWN", + "description": description, + "source_id": source_id, + "file_path": file_path, + "created_at": node_created_at, + } + added_entities.append(entity_data) + else: + # Node exists - update its source_ids by merging with new source_ids + updated = False # Track if any update occurred + + # 1. Get existing full source_ids from entity_chunks_storage + existing_full_source_ids = [] + if entity_chunks_storage is not None: + stored_chunks = await entity_chunks_storage.get_by_id(need_insert_id) + if stored_chunks and isinstance(stored_chunks, dict): + existing_full_source_ids = [ + chunk_id + for chunk_id in stored_chunks.get("chunk_ids", []) + if chunk_id + ] + + # If not in entity_chunks_storage, get from graph database + if not existing_full_source_ids: + if existing_node.get("source_id"): + existing_full_source_ids = existing_node["source_id"].split( + GRAPH_FIELD_SEP + ) + + # 2. Merge with new source_ids from this relationship + new_source_ids_from_relation = [ + chunk_id for chunk_id in source_ids if chunk_id + ] + merged_full_source_ids = merge_source_ids( + existing_full_source_ids, new_source_ids_from_relation + ) + + # 3. Save merged full list to entity_chunks_storage (conditional) + if ( + entity_chunks_storage is not None + and merged_full_source_ids != existing_full_source_ids + ): + updated = True + await entity_chunks_storage.upsert( + { + need_insert_id: { + "chunk_ids": merged_full_source_ids, + "count": len(merged_full_source_ids), + } + } + ) + + # 4. Apply source_ids limit for graph and vector db + limit_method = global_config.get( + "source_ids_limit_method", SOURCE_IDS_LIMIT_METHOD_KEEP + ) + max_source_limit = global_config.get("max_source_ids_per_entity") + limited_source_ids = apply_source_ids_limit( + merged_full_source_ids, + max_source_limit, + limit_method, + identifier=f"`{need_insert_id}`", + ) + + # 5. Update graph database and vector database with limited source_ids (conditional) + limited_source_id_str = GRAPH_FIELD_SEP.join(limited_source_ids) + + if limited_source_id_str != existing_node.get("source_id", ""): + updated = True + updated_node_data = { + **existing_node, + "source_id": limited_source_id_str, + } + await knowledge_graph_inst.upsert_node( + need_insert_id, node_data=updated_node_data + ) + + # Update vector database + if entity_vdb is not None: + entity_vdb_id = compute_mdhash_id(need_insert_id, prefix="ent-") + entity_content = ( + f"{need_insert_id}\n{existing_node.get('description', '')}" + ) + vdb_data = { + entity_vdb_id: { + "content": entity_content, + "entity_name": need_insert_id, + "source_id": limited_source_id_str, + "entity_type": existing_node.get("entity_type", "UNKNOWN"), + "file_path": existing_node.get( + "file_path", "unknown_source" + ), + } + } + await safe_vdb_operation_with_exception( + operation=lambda payload=vdb_data: entity_vdb.upsert(payload), + operation_name="existing_entity_update", + entity_name=need_insert_id, + max_retries=3, + retry_delay=0.1, + ) + + # 6. Log once at the end if any update occurred + if updated: + status_message = f"Chunks appended from relation: `{need_insert_id}`" + logger.info(status_message) + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = status_message + pipeline_status["history_messages"].append(status_message) + + edge_created_at = int(time.time()) + await knowledge_graph_inst.upsert_edge( + src_id, + tgt_id, + edge_data=dict( + weight=weight, + description=description, + keywords=keywords, + source_id=source_id, + file_path=file_path, + created_at=edge_created_at, + truncate=truncation_info, + ), + ) + + edge_data = dict( + src_id=src_id, + tgt_id=tgt_id, + description=description, + keywords=keywords, + source_id=source_id, + file_path=file_path, + created_at=edge_created_at, + truncate=truncation_info, + weight=weight, + ) + + # Sort src_id and tgt_id to ensure consistent ordering (smaller string first) + if src_id > tgt_id: + src_id, tgt_id = tgt_id, src_id + + if relationships_vdb is not None: + rel_vdb_id = compute_mdhash_id(src_id + tgt_id, prefix="rel-") + rel_vdb_id_reverse = compute_mdhash_id(tgt_id + src_id, prefix="rel-") + try: + await relationships_vdb.delete([rel_vdb_id, rel_vdb_id_reverse]) + except Exception as e: + logger.debug( + f"Could not delete old relationship vector records {rel_vdb_id}, {rel_vdb_id_reverse}: {e}" + ) + rel_content = f"{keywords}\t{src_id}\n{tgt_id}\n{description}" + vdb_data = { + rel_vdb_id: { + "src_id": src_id, + "tgt_id": tgt_id, + "source_id": source_id, + "content": rel_content, + "keywords": keywords, + "description": description, + "weight": weight, + "file_path": file_path, + } + } + await safe_vdb_operation_with_exception( + operation=lambda payload=vdb_data: relationships_vdb.upsert(payload), + operation_name="relationship_upsert", + entity_name=f"{src_id}-{tgt_id}", + max_retries=3, + retry_delay=0.2, + ) + + return edge_data + + +async def merge_nodes_and_edges( + chunk_results: list, + knowledge_graph_inst: BaseGraphStorage, + entity_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + global_config: dict[str, str], + full_entities_storage: BaseKVStorage = None, + full_relations_storage: BaseKVStorage = None, + doc_id: str = None, + pipeline_status: dict = None, + pipeline_status_lock=None, + llm_response_cache: BaseKVStorage | None = None, + entity_chunks_storage: BaseKVStorage | None = None, + relation_chunks_storage: BaseKVStorage | None = None, + current_file_number: int = 0, + total_files: int = 0, + file_path: str = "unknown_source", +) -> None: + """Two-phase merge: process all entities first, then all relationships + + This approach ensures data consistency by: + 1. Phase 1: Process all entities concurrently + 2. Phase 2: Process all relationships concurrently (may add missing entities) + 3. Phase 3: Update full_entities and full_relations storage with final results + + Args: + chunk_results: List of tuples (maybe_nodes, maybe_edges) containing extracted entities and relationships + knowledge_graph_inst: Knowledge graph storage + entity_vdb: Entity vector database + relationships_vdb: Relationship vector database + global_config: Global configuration + full_entities_storage: Storage for document entity lists + full_relations_storage: Storage for document relation lists + doc_id: Document ID for storage indexing + pipeline_status: Pipeline status dictionary + pipeline_status_lock: Lock for pipeline status + llm_response_cache: LLM response cache + entity_chunks_storage: Storage tracking full chunk lists per entity + relation_chunks_storage: Storage tracking full chunk lists per relation + current_file_number: Current file number for logging + total_files: Total files for logging + file_path: File path for logging + """ + + # Check for cancellation at the start of merge + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException("User cancelled during merge phase") + + # Collect all nodes and edges from all chunks + all_nodes = defaultdict(list) + all_edges = defaultdict(list) + + for maybe_nodes, maybe_edges in chunk_results: + # Collect nodes + for entity_name, entities in maybe_nodes.items(): + all_nodes[entity_name].extend(entities) + + # Collect edges with sorted keys for undirected graph + for edge_key, edges in maybe_edges.items(): + sorted_edge_key = tuple(sorted(edge_key)) + all_edges[sorted_edge_key].extend(edges) + + total_entities_count = len(all_nodes) + total_relations_count = len(all_edges) + + log_message = f"Merging stage {current_file_number}/{total_files}: {file_path}" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Get max async tasks limit from global_config for semaphore control + graph_max_async = global_config.get("llm_model_max_async", 4) * 2 + semaphore = asyncio.Semaphore(graph_max_async) + + # ===== Phase 1: Process all entities concurrently ===== + log_message = f"Phase 1: Processing {total_entities_count} entities from {doc_id} (async: {graph_max_async})" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + async def _locked_process_entity_name(entity_name, entities): + async with semaphore: + # Check for cancellation before processing entity + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during entity merge" + ) + + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" + async with get_storage_keyed_lock( + [entity_name], namespace=namespace, enable_logging=False + ): + try: + logger.debug(f"Processing entity {entity_name}") + entity_data = await _merge_nodes_then_upsert( + entity_name, + entities, + knowledge_graph_inst, + entity_vdb, + global_config, + pipeline_status, + pipeline_status_lock, + llm_response_cache, + entity_chunks_storage, + ) + + return entity_data + + except Exception as e: + error_msg = f"Error processing entity `{entity_name}`: {e}" + logger.error(error_msg) + + # Try to update pipeline status, but don't let status update failure affect main exception + try: + if ( + pipeline_status is not None + and pipeline_status_lock is not None + ): + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append(error_msg) + except Exception as status_error: + logger.error( + f"Failed to update pipeline status: {status_error}" + ) + + # Re-raise the original exception with a prefix + prefixed_exception = create_prefixed_exception( + e, f"`{entity_name}`" + ) + raise prefixed_exception from e + + # Create entity processing tasks + entity_tasks = [] + for entity_name, entities in all_nodes.items(): + task = asyncio.create_task(_locked_process_entity_name(entity_name, entities)) + entity_tasks.append(task) + + # Execute entity tasks with error handling + processed_entities = [] + if entity_tasks: + done, pending = await asyncio.wait( + entity_tasks, return_when=asyncio.FIRST_EXCEPTION + ) + + first_exception = None + processed_entities = [] + + for task in done: + try: + result = task.result() + except BaseException as e: + if first_exception is None: + first_exception = e + else: + processed_entities.append(result) + + if pending: + for task in pending: + task.cancel() + pending_results = await asyncio.gather(*pending, return_exceptions=True) + for result in pending_results: + if isinstance(result, BaseException): + if first_exception is None: + first_exception = result + else: + processed_entities.append(result) + + if first_exception is not None: + raise first_exception + + # ===== Phase 2: Process all relationships concurrently ===== + log_message = f"Phase 2: Processing {total_relations_count} relations from {doc_id} (async: {graph_max_async})" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + async def _locked_process_edges(edge_key, edges): + async with semaphore: + # Check for cancellation before processing edges + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during relation merge" + ) + + workspace = global_config.get("workspace", "") + namespace = f"{workspace}:GraphDB" if workspace else "GraphDB" + sorted_edge_key = sorted([edge_key[0], edge_key[1]]) + + async with get_storage_keyed_lock( + sorted_edge_key, + namespace=namespace, + enable_logging=False, + ): + try: + added_entities = [] # Track entities added during edge processing + + logger.debug(f"Processing relation {sorted_edge_key}") + edge_data = await _merge_edges_then_upsert( + edge_key[0], + edge_key[1], + edges, + knowledge_graph_inst, + relationships_vdb, + entity_vdb, + global_config, + pipeline_status, + pipeline_status_lock, + llm_response_cache, + added_entities, # Pass list to collect added entities + relation_chunks_storage, + entity_chunks_storage, # Add entity_chunks_storage parameter + ) + + if edge_data is None: + return None, [] + + return edge_data, added_entities + + except Exception as e: + error_msg = f"Error processing relation `{sorted_edge_key}`: {e}" + logger.error(error_msg) + + # Try to update pipeline status, but don't let status update failure affect main exception + try: + if ( + pipeline_status is not None + and pipeline_status_lock is not None + ): + async with pipeline_status_lock: + pipeline_status["latest_message"] = error_msg + pipeline_status["history_messages"].append(error_msg) + except Exception as status_error: + logger.error( + f"Failed to update pipeline status: {status_error}" + ) + + # Re-raise the original exception with a prefix + prefixed_exception = create_prefixed_exception( + e, f"{sorted_edge_key}" + ) + raise prefixed_exception from e + + # Create relationship processing tasks + edge_tasks = [] + for edge_key, edges in all_edges.items(): + task = asyncio.create_task(_locked_process_edges(edge_key, edges)) + edge_tasks.append(task) + + # Execute relationship tasks with error handling + processed_edges = [] + all_added_entities = [] + + if edge_tasks: + done, pending = await asyncio.wait( + edge_tasks, return_when=asyncio.FIRST_EXCEPTION + ) + + first_exception = None + + for task in done: + try: + edge_data, added_entities = task.result() + except BaseException as e: + if first_exception is None: + first_exception = e + else: + if edge_data is not None: + processed_edges.append(edge_data) + all_added_entities.extend(added_entities) + + if pending: + for task in pending: + task.cancel() + pending_results = await asyncio.gather(*pending, return_exceptions=True) + for result in pending_results: + if isinstance(result, BaseException): + if first_exception is None: + first_exception = result + else: + edge_data, added_entities = result + if edge_data is not None: + processed_edges.append(edge_data) + all_added_entities.extend(added_entities) + + if first_exception is not None: + raise first_exception + + # ===== Phase 3: Update full_entities and full_relations storage ===== + if full_entities_storage and full_relations_storage and doc_id: + try: + # Merge all entities: original entities + entities added during edge processing + final_entity_names = set() + + # Add original processed entities + for entity_data in processed_entities: + if entity_data and entity_data.get("entity_name"): + final_entity_names.add(entity_data["entity_name"]) + + # Add entities that were added during relationship processing + for added_entity in all_added_entities: + if added_entity and added_entity.get("entity_name"): + final_entity_names.add(added_entity["entity_name"]) + + # Collect all relation pairs + final_relation_pairs = set() + for edge_data in processed_edges: + if edge_data: + src_id = edge_data.get("src_id") + tgt_id = edge_data.get("tgt_id") + if src_id and tgt_id: + relation_pair = tuple(sorted([src_id, tgt_id])) + final_relation_pairs.add(relation_pair) + + log_message = f"Phase 3: Updating final {len(final_entity_names)}({len(processed_entities)}+{len(all_added_entities)}) entities and {len(final_relation_pairs)} relations from {doc_id}" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Update storage + if final_entity_names: + await full_entities_storage.upsert( + { + doc_id: { + "entity_names": list(final_entity_names), + "count": len(final_entity_names), + } + } + ) + + if final_relation_pairs: + await full_relations_storage.upsert( + { + doc_id: { + "relation_pairs": [ + list(pair) for pair in final_relation_pairs + ], + "count": len(final_relation_pairs), + } + } + ) + + logger.debug( + f"Updated entity-relation index for document {doc_id}: {len(final_entity_names)} entities (original: {len(processed_entities)}, added: {len(all_added_entities)}), {len(final_relation_pairs)} relations" + ) + + except Exception as e: + logger.error( + f"Failed to update entity-relation index for document {doc_id}: {e}" + ) + # Don't raise exception to avoid affecting main flow + + log_message = f"Completed merging: {len(processed_entities)} entities, {len(all_added_entities)} extra entities, {len(processed_edges)} relations" + logger.info(log_message) + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + +async def extract_entities( + chunks: dict[str, TextChunkSchema], + global_config: dict[str, str], + pipeline_status: dict = None, + pipeline_status_lock=None, + llm_response_cache: BaseKVStorage | None = None, + text_chunks_storage: BaseKVStorage | None = None, +) -> list: + # Check for cancellation at the start of entity extraction + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during entity extraction" + ) + + use_llm_func: callable = global_config["llm_model_func"] + entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] + + ordered_chunks = list(chunks.items()) + # add language and example number params to prompt + language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) + entity_types = global_config["addon_params"].get( + "entity_types", DEFAULT_ENTITY_TYPES + ) + + examples = "\n".join(PROMPTS["entity_extraction_examples"]) + + example_context_base = dict( + tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], + completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], + entity_types=", ".join(entity_types), + language=language, + ) + # add example's format + examples = examples.format(**example_context_base) + + context_base = dict( + tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"], + completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"], + entity_types=",".join(entity_types), + examples=examples, + language=language, + ) + + processed_chunks = 0 + total_chunks = len(ordered_chunks) + + async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]): + """Process a single chunk + Args: + chunk_key_dp (tuple[str, TextChunkSchema]): + ("chunk-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int}) + Returns: + tuple: (maybe_nodes, maybe_edges) containing extracted entities and relationships + """ + nonlocal processed_chunks + chunk_key = chunk_key_dp[0] + chunk_dp = chunk_key_dp[1] + content = chunk_dp["content"] + # Get file path from chunk data or use default + file_path = chunk_dp.get("file_path", "unknown_source") + + # Create cache keys collector for batch processing + cache_keys_collector = [] + + # Get initial extraction + # Format system prompt without input_text for each chunk (enables OpenAI prompt caching across chunks) + entity_extraction_system_prompt = PROMPTS[ + "entity_extraction_system_prompt" + ].format(**context_base) + # Format user prompts with input_text for each chunk + entity_extraction_user_prompt = PROMPTS["entity_extraction_user_prompt"].format( + **{**context_base, "input_text": content} + ) + entity_continue_extraction_user_prompt = PROMPTS[ + "entity_continue_extraction_user_prompt" + ].format(**{**context_base, "input_text": content}) + + final_result, timestamp = await use_llm_func_with_cache( + entity_extraction_user_prompt, + use_llm_func, + system_prompt=entity_extraction_system_prompt, + llm_response_cache=llm_response_cache, + cache_type="extract", + chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, + ) + + history = pack_user_ass_to_openai_messages( + entity_extraction_user_prompt, final_result + ) + + # Process initial extraction with file path + maybe_nodes, maybe_edges = await _process_extraction_result( + final_result, + chunk_key, + timestamp, + file_path, + tuple_delimiter=context_base["tuple_delimiter"], + completion_delimiter=context_base["completion_delimiter"], + ) + + # Process additional gleaning results only 1 time when entity_extract_max_gleaning is greater than zero. + if entity_extract_max_gleaning > 0: + # Calculate total tokens for the gleaning request to prevent context window overflow + tokenizer = global_config["tokenizer"] + max_input_tokens = global_config["max_extract_input_tokens"] + + # Approximate total tokens: system prompt + history + user prompt. + # This slightly underestimates actual API usage (missing role/framing tokens) + # but is sufficient as a safety guard against context window overflow. + history_str = json.dumps(history, ensure_ascii=False) + full_context_str = ( + entity_extraction_system_prompt + + history_str + + entity_continue_extraction_user_prompt + ) + token_count = len(tokenizer.encode(full_context_str)) + + if token_count > max_input_tokens: + logger.warning( + f"Gleaning stopped for chunk {chunk_key}: Input tokens ({token_count}) exceeded limit ({max_input_tokens})." + ) + else: + glean_result, timestamp = await use_llm_func_with_cache( + entity_continue_extraction_user_prompt, + use_llm_func, + system_prompt=entity_extraction_system_prompt, + llm_response_cache=llm_response_cache, + history_messages=history, + cache_type="extract", + chunk_id=chunk_key, + cache_keys_collector=cache_keys_collector, + ) + + # Process gleaning result separately with file path + glean_nodes, glean_edges = await _process_extraction_result( + glean_result, + chunk_key, + timestamp, + file_path, + tuple_delimiter=context_base["tuple_delimiter"], + completion_delimiter=context_base["completion_delimiter"], + ) + + # Merge results - compare description lengths to choose better version + for entity_name, glean_entities in glean_nodes.items(): + if entity_name in maybe_nodes: + # Compare description lengths and keep the better one + original_desc_len = len( + maybe_nodes[entity_name][0].get("description", "") or "" + ) + glean_desc_len = len( + glean_entities[0].get("description", "") or "" + ) + + if glean_desc_len > original_desc_len: + maybe_nodes[entity_name] = list(glean_entities) + # Otherwise keep original version + else: + # New entity from gleaning stage + maybe_nodes[entity_name] = list(glean_entities) + + for edge_key, glean_edge_list in glean_edges.items(): + if edge_key in maybe_edges: + # Compare description lengths and keep the better one + original_desc_len = len( + maybe_edges[edge_key][0].get("description", "") or "" + ) + glean_desc_len = len( + glean_edge_list[0].get("description", "") or "" + ) + + if glean_desc_len > original_desc_len: + maybe_edges[edge_key] = list(glean_edge_list) + # Otherwise keep original version + else: + # New edge from gleaning stage + maybe_edges[edge_key] = list(glean_edge_list) + + # Batch update chunk's llm_cache_list with all collected cache keys + if cache_keys_collector and text_chunks_storage: + await update_chunk_cache_list( + chunk_key, + text_chunks_storage, + cache_keys_collector, + "entity_extraction", + ) + + processed_chunks += 1 + entities_count = len(maybe_nodes) + relations_count = len(maybe_edges) + log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel {chunk_key}" + logger.info(log_message) + if pipeline_status is not None: + async with pipeline_status_lock: + pipeline_status["latest_message"] = log_message + pipeline_status["history_messages"].append(log_message) + + # Return the extracted nodes and edges for centralized processing + return maybe_nodes, maybe_edges + + # Get max async tasks limit from global_config + chunk_max_async = global_config.get("llm_model_max_async", 4) + semaphore = asyncio.Semaphore(chunk_max_async) + + async def _process_with_semaphore(chunk): + async with semaphore: + # Check for cancellation before processing chunk + if pipeline_status is not None and pipeline_status_lock is not None: + async with pipeline_status_lock: + if pipeline_status.get("cancellation_requested", False): + raise PipelineCancelledException( + "User cancelled during chunk processing" + ) + + try: + return await _process_single_content(chunk) + except Exception as e: + chunk_id = chunk[0] # Extract chunk_id from chunk[0] + prefixed_exception = create_prefixed_exception(e, chunk_id) + raise prefixed_exception from e + + tasks = [] + for c in ordered_chunks: + task = asyncio.create_task(_process_with_semaphore(c)) + tasks.append(task) + + # Wait for tasks to complete or for the first exception to occur + # This allows us to cancel remaining tasks if any task fails + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + + # Check if any task raised an exception and ensure all exceptions are retrieved + first_exception = None + chunk_results = [] + + for task in done: + try: + exception = task.exception() + if exception is not None: + if first_exception is None: + first_exception = exception + else: + chunk_results.append(task.result()) + except Exception as e: + if first_exception is None: + first_exception = e + + # If any task failed, cancel all pending tasks and raise the first exception + if first_exception is not None: + # Cancel all pending tasks + for pending_task in pending: + pending_task.cancel() + + # Wait for cancellation to complete + if pending: + await asyncio.wait(pending) + + # Add progress prefix to the exception message + progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]" + + # Re-raise the original exception with a prefix + prefixed_exception = create_prefixed_exception(first_exception, progress_prefix) + raise prefixed_exception from first_exception + + # If all tasks completed successfully, chunk_results already contains the results + # Return the chunk_results for later processing in merge_nodes_and_edges + return chunk_results + + +async def kg_query( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + chunks_vdb: BaseVectorStorage = None, +) -> QueryResult | None: + """ + Execute knowledge graph query and return unified QueryResult object. + + Args: + query: Query string + knowledge_graph_inst: Knowledge graph storage instance + entities_vdb: Entity vector database + relationships_vdb: Relationship vector database + text_chunks_db: Text chunks storage + query_param: Query parameters + global_config: Global configuration + hashing_kv: Cache storage + system_prompt: System prompt + chunks_vdb: Document chunks vector database + + Returns: + QueryResult | None: Unified query result object containing: + - content: Non-streaming response text content + - response_iterator: Streaming response iterator + - raw_data: Complete structured data (including references and metadata) + - is_streaming: Whether this is a streaming result + + Based on different query_param settings, different fields will be populated: + - only_need_context=True: content contains context string + - only_need_prompt=True: content contains complete prompt + - stream=True: response_iterator contains streaming response, raw_data contains complete data + - default: content contains LLM response text, raw_data contains complete data + + Returns None when no relevant context could be constructed for the query. + """ + if not query: + return QueryResult(content=PROMPTS["fail_response"]) + + if query_param.model_func: + use_model_func = query_param.model_func + else: + use_model_func = global_config["llm_model_func"] + # Apply higher priority (5) to query relation LLM function + use_model_func = partial(use_model_func, _priority=5) + + hl_keywords, ll_keywords = await get_keywords_from_query( + query, query_param, global_config, hashing_kv + ) + + logger.debug(f"High-level keywords: {hl_keywords}") + logger.debug(f"Low-level keywords: {ll_keywords}") + + # Handle empty keywords + if ll_keywords == [] and query_param.mode in ["local", "hybrid", "mix"]: + logger.warning("low_level_keywords is empty") + if hl_keywords == [] and query_param.mode in ["global", "hybrid", "mix"]: + logger.warning("high_level_keywords is empty") + if hl_keywords == [] and ll_keywords == []: + if len(query) < 50: + logger.warning(f"Forced low_level_keywords to origin query: {query}") + ll_keywords = [query] + else: + return QueryResult(content=PROMPTS["fail_response"]) + + ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else "" + hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else "" + + # Build query context (unified interface) + context_result = await _build_query_context( + query, + ll_keywords_str, + hl_keywords_str, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + chunks_vdb, + ) + + if context_result is None: + logger.info("[kg_query] No query context could be built; returning no-result.") + return None + + # Return different content based on query parameters + if query_param.only_need_context and not query_param.only_need_prompt: + return QueryResult( + content=context_result.context, raw_data=context_result.raw_data + ) + + user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" + response_type = ( + query_param.response_type + if query_param.response_type + else "Multiple Paragraphs" + ) + + # Build system prompt + sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"] + sys_prompt = sys_prompt_temp.format( + response_type=response_type, + user_prompt=user_prompt, + context_data=context_result.context, + ) + + user_query = query + + if query_param.only_need_prompt: + prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) + return QueryResult(content=prompt_content, raw_data=context_result.raw_data) + + # Call LLM + tokenizer: Tokenizer = global_config["tokenizer"] + len_of_prompts = len(tokenizer.encode(query + sys_prompt)) + logger.debug( + f"[kg_query] Sending to LLM: {len_of_prompts:,} tokens (Query: {len(tokenizer.encode(query))}, System: {len(tokenizer.encode(sys_prompt))})" + ) + + # Handle cache + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + hl_keywords_str, + ll_keywords_str, + query_param.user_prompt or "", + query_param.enable_rerank, + ) + + cached_result = await handle_cache( + hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" + ) + + if cached_result is not None: + cached_response, _ = cached_result # Extract content, ignore timestamp + logger.info( + " == LLM cache == Query cache hit, using cached response as query result" + ) + response = cached_response + else: + response = await use_model_func( + user_query, + system_prompt=sys_prompt, + history_messages=query_param.conversation_history, + enable_cot=True, + stream=query_param.stream, + ) + + if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): + queryparam_dict = { + "mode": query_param.mode, + "response_type": query_param.response_type, + "top_k": query_param.top_k, + "chunk_top_k": query_param.chunk_top_k, + "max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "hl_keywords": hl_keywords_str, + "ll_keywords": ll_keywords_str, + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, + } + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + mode=query_param.mode, + cache_type="query", + queryparam=queryparam_dict, + ), + ) + + # Return unified result based on actual response type + if isinstance(response, str): + # Non-streaming response (string) + if len(response) > len(sys_prompt): + response = ( + response.replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + + return QueryResult(content=response, raw_data=context_result.raw_data) + else: + # Streaming response (AsyncIterator) + return QueryResult( + response_iterator=response, + raw_data=context_result.raw_data, + is_streaming=True, + ) + + +async def get_keywords_from_query( + query: str, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, +) -> tuple[list[str], list[str]]: + """ + Retrieves high-level and low-level keywords for RAG operations. + + This function checks if keywords are already provided in query parameters, + and if not, extracts them from the query text using LLM. + + Args: + query: The user's query text + query_param: Query parameters that may contain pre-defined keywords + global_config: Global configuration dictionary + hashing_kv: Optional key-value storage for caching results + + Returns: + A tuple containing (high_level_keywords, low_level_keywords) + """ + # Check if pre-defined keywords are already provided + if query_param.hl_keywords or query_param.ll_keywords: + return query_param.hl_keywords, query_param.ll_keywords + + # Extract keywords using extract_keywords_only function which already supports conversation history + hl_keywords, ll_keywords = await extract_keywords_only( + query, query_param, global_config, hashing_kv + ) + return hl_keywords, ll_keywords + + +async def extract_keywords_only( + text: str, + param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, +) -> tuple[list[str], list[str]]: + """ + Extract high-level and low-level keywords from the given 'text' using the LLM. + This method does NOT build the final RAG context or provide a final answer. + It ONLY extracts keywords (hl_keywords, ll_keywords). + """ + + # 1. Build the examples + examples = "\n".join(PROMPTS["keywords_extraction_examples"]) + + language = global_config["addon_params"].get("language", DEFAULT_SUMMARY_LANGUAGE) + + # 2. Handle cache if needed - add cache type for keywords + args_hash = compute_args_hash( + param.mode, + text, + language, + ) + cached_result = await handle_cache( + hashing_kv, args_hash, text, param.mode, cache_type="keywords" + ) + if cached_result is not None: + cached_response, _ = cached_result # Extract content, ignore timestamp + try: + keywords_data = json_repair.loads(cached_response) + return keywords_data.get("high_level_keywords", []), keywords_data.get( + "low_level_keywords", [] + ) + except (json.JSONDecodeError, KeyError): + logger.warning( + "Invalid cache format for keywords, proceeding with extraction" + ) + + # 3. Build the keyword-extraction prompt + kw_prompt = PROMPTS["keywords_extraction"].format( + query=text, + examples=examples, + language=language, + ) + + tokenizer: Tokenizer = global_config["tokenizer"] + len_of_prompts = len(tokenizer.encode(kw_prompt)) + logger.debug( + f"[extract_keywords] Sending to LLM: {len_of_prompts:,} tokens (Prompt: {len_of_prompts})" + ) + + # 4. Call the LLM for keyword extraction + if param.model_func: + use_model_func = param.model_func + else: + use_model_func = global_config["llm_model_func"] + # Apply higher priority (5) to query relation LLM function + use_model_func = partial(use_model_func, _priority=5) + + result = await use_model_func(kw_prompt, keyword_extraction=True) + + # 5. Parse out JSON from the LLM response + result = remove_think_tags(result) + try: + keywords_data = json_repair.loads(result) + if not keywords_data: + logger.error("No JSON-like structure found in the LLM respond.") + return [], [] + except json.JSONDecodeError as e: + logger.error(f"JSON parsing error: {e}") + logger.error(f"LLM respond: {result}") + return [], [] + + hl_keywords = keywords_data.get("high_level_keywords", []) + ll_keywords = keywords_data.get("low_level_keywords", []) + + # 6. Cache only the processed keywords with cache type + if hl_keywords or ll_keywords: + cache_data = { + "high_level_keywords": hl_keywords, + "low_level_keywords": ll_keywords, + } + if hashing_kv.global_config.get("enable_llm_cache"): + # Save to cache with query parameters + queryparam_dict = { + "mode": param.mode, + "response_type": param.response_type, + "top_k": param.top_k, + "chunk_top_k": param.chunk_top_k, + "max_entity_tokens": param.max_entity_tokens, + "max_relation_tokens": param.max_relation_tokens, + "max_total_tokens": param.max_total_tokens, + "user_prompt": param.user_prompt or "", + "enable_rerank": param.enable_rerank, + } + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=json.dumps(cache_data), + prompt=text, + mode=param.mode, + cache_type="keywords", + queryparam=queryparam_dict, + ), + ) + + return hl_keywords, ll_keywords + + +async def _get_vector_context( + query: str, + chunks_vdb: BaseVectorStorage, + query_param: QueryParam, + query_embedding: list[float] = None, +) -> list[dict]: + """ + Retrieve text chunks from the vector database without reranking or truncation. + + This function performs vector search to find relevant text chunks for a query. + Reranking and truncation will be handled later in the unified processing. + + Args: + query: The query string to search for + chunks_vdb: Vector database containing document chunks + query_param: Query parameters including chunk_top_k and ids + query_embedding: Optional pre-computed query embedding to avoid redundant embedding calls + + Returns: + List of text chunks with metadata + """ + try: + # Use chunk_top_k if specified, otherwise fall back to top_k + search_top_k = query_param.chunk_top_k or query_param.top_k + cosine_threshold = chunks_vdb.cosine_better_than_threshold + + results = await chunks_vdb.query( + query, top_k=search_top_k, query_embedding=query_embedding + ) + if not results: + logger.info( + f"Naive query: 0 chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})" + ) + return [] + + valid_chunks = [] + for result in results: + if "content" in result: + chunk_with_metadata = { + "content": result["content"], + "created_at": result.get("created_at", None), + "file_path": result.get("file_path", "unknown_source"), + "source_type": "vector", # Mark the source type + "chunk_id": result.get("id"), # Add chunk_id for deduplication + } + valid_chunks.append(chunk_with_metadata) + + logger.info( + f"Naive query: {len(valid_chunks)} chunks (chunk_top_k:{search_top_k} cosine:{cosine_threshold})" + ) + return valid_chunks + + except Exception as e: + logger.error(f"Error in _get_vector_context: {e}") + return [] + + +async def _perform_kg_search( + query: str, + ll_keywords: str, + hl_keywords: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage, + query_param: QueryParam, + chunks_vdb: BaseVectorStorage = None, +) -> dict[str, Any]: + """ + Pure search logic that retrieves raw entities, relations, and vector chunks. + No token truncation or formatting - just raw search results. + """ + + # Initialize result containers + local_entities = [] + local_relations = [] + global_entities = [] + global_relations = [] + vector_chunks = [] + chunk_tracking = {} + + # Handle different query modes + + # Track chunk sources and metadata for final logging + chunk_tracking = {} # chunk_id -> {source, frequency, order} + + # Pre-compute query embedding once for all vector operations + kg_chunk_pick_method = text_chunks_db.global_config.get( + "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD + ) + query_embedding = None + if query and (kg_chunk_pick_method == "VECTOR" or chunks_vdb): + actual_embedding_func = text_chunks_db.embedding_func + if actual_embedding_func: + try: + query_embedding = await actual_embedding_func([query]) + query_embedding = query_embedding[ + 0 + ] # Extract first embedding from batch result + logger.debug("Pre-computed query embedding for all vector operations") + except Exception as e: + logger.warning(f"Failed to pre-compute query embedding: {e}") + query_embedding = None + + # Handle local and global modes + if query_param.mode == "local" and len(ll_keywords) > 0: + local_entities, local_relations = await _get_node_data( + ll_keywords, + knowledge_graph_inst, + entities_vdb, + query_param, + ) + + elif query_param.mode == "global" and len(hl_keywords) > 0: + global_relations, global_entities = await _get_edge_data( + hl_keywords, + knowledge_graph_inst, + relationships_vdb, + query_param, + ) + + else: # hybrid or mix mode + if len(ll_keywords) > 0: + local_entities, local_relations = await _get_node_data( + ll_keywords, + knowledge_graph_inst, + entities_vdb, + query_param, + ) + if len(hl_keywords) > 0: + global_relations, global_entities = await _get_edge_data( + hl_keywords, + knowledge_graph_inst, + relationships_vdb, + query_param, + ) + + # Get vector chunks for mix mode + if query_param.mode == "mix" and chunks_vdb: + vector_chunks = await _get_vector_context( + query, + chunks_vdb, + query_param, + query_embedding, + ) + # Track vector chunks with source metadata + for i, chunk in enumerate(vector_chunks): + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id: + chunk_tracking[chunk_id] = { + "source": "C", + "frequency": 1, # Vector chunks always have frequency 1 + "order": i + 1, # 1-based order in vector search results + } + else: + logger.warning(f"Vector chunk missing chunk_id: {chunk}") + + # Round-robin merge entities + final_entities = [] + seen_entities = set() + max_len = max(len(local_entities), len(global_entities)) + for i in range(max_len): + # First from local + if i < len(local_entities): + entity = local_entities[i] + entity_name = entity.get("entity_name") + if entity_name and entity_name not in seen_entities: + final_entities.append(entity) + seen_entities.add(entity_name) + + # Then from global + if i < len(global_entities): + entity = global_entities[i] + entity_name = entity.get("entity_name") + if entity_name and entity_name not in seen_entities: + final_entities.append(entity) + seen_entities.add(entity_name) + + # Round-robin merge relations + final_relations = [] + seen_relations = set() + max_len = max(len(local_relations), len(global_relations)) + for i in range(max_len): + # First from local + if i < len(local_relations): + relation = local_relations[i] + # Build relation unique identifier + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) + else: + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) + + if rel_key not in seen_relations: + final_relations.append(relation) + seen_relations.add(rel_key) + + # Then from global + if i < len(global_relations): + relation = global_relations[i] + # Build relation unique identifier + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) + else: + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) + + if rel_key not in seen_relations: + final_relations.append(relation) + seen_relations.add(rel_key) + + logger.info( + f"Raw search results: {len(final_entities)} entities, {len(final_relations)} relations, {len(vector_chunks)} vector chunks" + ) + + return { + "final_entities": final_entities, + "final_relations": final_relations, + "vector_chunks": vector_chunks, + "chunk_tracking": chunk_tracking, + "query_embedding": query_embedding, + } + + +async def _apply_token_truncation( + search_result: dict[str, Any], + query_param: QueryParam, + global_config: dict[str, str], +) -> dict[str, Any]: + """ + Apply token-based truncation to entities and relations for LLM efficiency. + """ + tokenizer = global_config.get("tokenizer") + if not tokenizer: + logger.warning("No tokenizer found, skipping truncation") + return { + "entities_context": [], + "relations_context": [], + "filtered_entities": search_result["final_entities"], + "filtered_relations": search_result["final_relations"], + "entity_id_to_original": {}, + "relation_id_to_original": {}, + } + + # Get token limits from query_param with fallbacks + max_entity_tokens = getattr( + query_param, + "max_entity_tokens", + global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS), + ) + max_relation_tokens = getattr( + query_param, + "max_relation_tokens", + global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS), + ) + + final_entities = search_result["final_entities"] + final_relations = search_result["final_relations"] + + # Create mappings from entity/relation identifiers to original data + entity_id_to_original = {} + relation_id_to_original = {} + + # Generate entities context for truncation + entities_context = [] + for i, entity in enumerate(final_entities): + entity_name = entity["entity_name"] + created_at = entity.get("created_at", "UNKNOWN") + if isinstance(created_at, (int, float)): + created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Store mapping from entity name to original data + entity_id_to_original[entity_name] = entity + + entities_context.append( + { + "entity": entity_name, + "type": entity.get("entity_type", "UNKNOWN"), + "description": entity.get("description", "UNKNOWN"), + "created_at": created_at, + "file_path": entity.get("file_path", "unknown_source"), + } + ) + + # Generate relations context for truncation + relations_context = [] + for i, relation in enumerate(final_relations): + created_at = relation.get("created_at", "UNKNOWN") + if isinstance(created_at, (int, float)): + created_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(created_at)) + + # Handle different relation data formats + if "src_tgt" in relation: + entity1, entity2 = relation["src_tgt"] + else: + entity1, entity2 = relation.get("src_id"), relation.get("tgt_id") + + # Store mapping from relation pair to original data + relation_key = (entity1, entity2) + relation_id_to_original[relation_key] = relation + + relations_context.append( + { + "entity1": entity1, + "entity2": entity2, + "description": relation.get("description", "UNKNOWN"), + "created_at": created_at, + "file_path": relation.get("file_path", "unknown_source"), + } + ) + + logger.debug( + f"Before truncation: {len(entities_context)} entities, {len(relations_context)} relations" + ) + + # Apply token-based truncation + if entities_context: + # Remove file_path and created_at for token calculation + entities_context_for_truncation = [] + for entity in entities_context: + entity_copy = entity.copy() + entity_copy.pop("file_path", None) + entity_copy.pop("created_at", None) + entities_context_for_truncation.append(entity_copy) + + entities_context = truncate_list_by_token_size( + entities_context_for_truncation, + key=lambda x: "\n".join( + json.dumps(item, ensure_ascii=False) for item in [x] + ), + max_token_size=max_entity_tokens, + tokenizer=tokenizer, + ) + + if relations_context: + # Remove file_path and created_at for token calculation + relations_context_for_truncation = [] + for relation in relations_context: + relation_copy = relation.copy() + relation_copy.pop("file_path", None) + relation_copy.pop("created_at", None) + relations_context_for_truncation.append(relation_copy) + + relations_context = truncate_list_by_token_size( + relations_context_for_truncation, + key=lambda x: "\n".join( + json.dumps(item, ensure_ascii=False) for item in [x] + ), + max_token_size=max_relation_tokens, + tokenizer=tokenizer, + ) + + logger.info( + f"After truncation: {len(entities_context)} entities, {len(relations_context)} relations" + ) + + # Create filtered original data based on truncated context + filtered_entities = [] + filtered_entity_id_to_original = {} + if entities_context: + final_entity_names = {e["entity"] for e in entities_context} + seen_nodes = set() + for entity in final_entities: + name = entity.get("entity_name") + if name in final_entity_names and name not in seen_nodes: + filtered_entities.append(entity) + filtered_entity_id_to_original[name] = entity + seen_nodes.add(name) + + filtered_relations = [] + filtered_relation_id_to_original = {} + if relations_context: + final_relation_pairs = {(r["entity1"], r["entity2"]) for r in relations_context} + seen_edges = set() + for relation in final_relations: + src, tgt = relation.get("src_id"), relation.get("tgt_id") + if src is None or tgt is None: + src, tgt = relation.get("src_tgt", (None, None)) + + pair = (src, tgt) + if pair in final_relation_pairs and pair not in seen_edges: + filtered_relations.append(relation) + filtered_relation_id_to_original[pair] = relation + seen_edges.add(pair) + + return { + "entities_context": entities_context, + "relations_context": relations_context, + "filtered_entities": filtered_entities, + "filtered_relations": filtered_relations, + "entity_id_to_original": filtered_entity_id_to_original, + "relation_id_to_original": filtered_relation_id_to_original, + } + + +async def _merge_all_chunks( + filtered_entities: list[dict], + filtered_relations: list[dict], + vector_chunks: list[dict], + query: str = "", + knowledge_graph_inst: BaseGraphStorage = None, + text_chunks_db: BaseKVStorage = None, + query_param: QueryParam = None, + chunks_vdb: BaseVectorStorage = None, + chunk_tracking: dict = None, + query_embedding: list[float] = None, +) -> list[dict]: + """ + Merge chunks from different sources: vector_chunks + entity_chunks + relation_chunks. + """ + if chunk_tracking is None: + chunk_tracking = {} + + # Get chunks from entities + entity_chunks = [] + if filtered_entities and text_chunks_db: + entity_chunks = await _find_related_text_unit_from_entities( + filtered_entities, + query_param, + text_chunks_db, + knowledge_graph_inst, + query, + chunks_vdb, + chunk_tracking=chunk_tracking, + query_embedding=query_embedding, + ) + + # Get chunks from relations + relation_chunks = [] + if filtered_relations and text_chunks_db: + relation_chunks = await _find_related_text_unit_from_relations( + filtered_relations, + query_param, + text_chunks_db, + entity_chunks, # For deduplication + query, + chunks_vdb, + chunk_tracking=chunk_tracking, + query_embedding=query_embedding, + ) + + # Round-robin merge chunks from different sources with deduplication + merged_chunks = [] + seen_chunk_ids = set() + max_len = max(len(vector_chunks), len(entity_chunks), len(relation_chunks)) + origin_len = len(vector_chunks) + len(entity_chunks) + len(relation_chunks) + + for i in range(max_len): + # Add from vector chunks first (Naive mode) + if i < len(vector_chunks): + chunk = vector_chunks[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id and chunk_id not in seen_chunk_ids: + seen_chunk_ids.add(chunk_id) + merged_chunks.append( + { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, + } + ) + + # Add from entity chunks (Local mode) + if i < len(entity_chunks): + chunk = entity_chunks[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id and chunk_id not in seen_chunk_ids: + seen_chunk_ids.add(chunk_id) + merged_chunks.append( + { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, + } + ) + + # Add from relation chunks (Global mode) + if i < len(relation_chunks): + chunk = relation_chunks[i] + chunk_id = chunk.get("chunk_id") or chunk.get("id") + if chunk_id and chunk_id not in seen_chunk_ids: + seen_chunk_ids.add(chunk_id) + merged_chunks.append( + { + "content": chunk["content"], + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk_id, + } + ) + + logger.info( + f"Round-robin merged chunks: {origin_len} -> {len(merged_chunks)} (deduplicated {origin_len - len(merged_chunks)})" + ) + + return merged_chunks + + +async def _build_context_str( + entities_context: list[dict], + relations_context: list[dict], + merged_chunks: list[dict], + query: str, + query_param: QueryParam, + global_config: dict[str, str], + chunk_tracking: dict = None, + entity_id_to_original: dict = None, + relation_id_to_original: dict = None, +) -> tuple[str, dict[str, Any]]: + """ + Build the final LLM context string with token processing. + This includes dynamic token calculation and final chunk truncation. + """ + tokenizer = global_config.get("tokenizer") + if not tokenizer: + logger.error("Missing tokenizer, cannot build LLM context") + # Return empty raw data structure when no tokenizer + empty_raw_data = convert_to_user_format( + [], + [], + [], + [], + query_param.mode, + ) + empty_raw_data["status"] = "failure" + empty_raw_data["message"] = "Missing tokenizer, cannot build LLM context." + return "", empty_raw_data + + # Get token limits + max_total_tokens = getattr( + query_param, + "max_total_tokens", + global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), + ) + + # Get the system prompt template from PROMPTS or global_config + sys_prompt_template = global_config.get( + "system_prompt_template", PROMPTS["rag_response"] + ) + + kg_context_template = PROMPTS["kg_query_context"] + user_prompt = query_param.user_prompt if query_param.user_prompt else "" + response_type = ( + query_param.response_type + if query_param.response_type + else "Multiple Paragraphs" + ) + + entities_str = "\n".join( + json.dumps(entity, ensure_ascii=False) for entity in entities_context + ) + relations_str = "\n".join( + json.dumps(relation, ensure_ascii=False) for relation in relations_context + ) + + # Calculate preliminary kg context tokens + pre_kg_context = kg_context_template.format( + entities_str=entities_str, + relations_str=relations_str, + text_chunks_str="", + reference_list_str="", + ) + kg_context_tokens = len(tokenizer.encode(pre_kg_context)) + + # Calculate preliminary system prompt tokens + pre_sys_prompt = sys_prompt_template.format( + context_data="", # Empty for overhead calculation + response_type=response_type, + user_prompt=user_prompt, + ) + sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt)) + + # Calculate available tokens for text chunks + query_tokens = len(tokenizer.encode(query)) + buffer_tokens = 200 # reserved for reference list and safety buffer + available_chunk_tokens = max_total_tokens - ( + sys_prompt_tokens + kg_context_tokens + query_tokens + buffer_tokens + ) + + logger.debug( + f"Token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, KG: {kg_context_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" + ) + + # Apply token truncation to chunks using the dynamic limit + truncated_chunks = await process_chunks_unified( + query=query, + unique_chunks=merged_chunks, + query_param=query_param, + global_config=global_config, + source_type=query_param.mode, + chunk_token_limit=available_chunk_tokens, # Pass dynamic limit + ) + + # Generate reference list from truncated chunks using the new common function + reference_list, truncated_chunks = generate_reference_list_from_chunks( + truncated_chunks + ) + + # Rebuild chunks_context with truncated chunks + # The actual tokens may be slightly less than available_chunk_tokens due to deduplication logic + chunks_context = [] + for i, chunk in enumerate(truncated_chunks): + chunks_context.append( + { + "reference_id": chunk["reference_id"], + "content": chunk["content"], + } + ) + + text_units_str = "\n".join( + json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context + ) + reference_list_str = "\n".join( + f"[{ref['reference_id']}] {ref['file_path']}" + for ref in reference_list + if ref["reference_id"] + ) + + logger.info( + f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(chunks_context)} chunks" + ) + + # not necessary to use LLM to generate a response + if not entities_context and not relations_context and not chunks_context: + # Return empty raw data structure when no entities/relations + empty_raw_data = convert_to_user_format( + [], + [], + [], + [], + query_param.mode, + ) + empty_raw_data["status"] = "failure" + empty_raw_data["message"] = "Query returned empty dataset." + return "", empty_raw_data + + # output chunks tracking infomations + # format: / (e.g., E5/2 R2/1 C1/1) + if truncated_chunks and chunk_tracking: + chunk_tracking_log = [] + for chunk in truncated_chunks: + chunk_id = chunk.get("chunk_id") + if chunk_id and chunk_id in chunk_tracking: + tracking_info = chunk_tracking[chunk_id] + source = tracking_info["source"] + frequency = tracking_info["frequency"] + order = tracking_info["order"] + chunk_tracking_log.append(f"{source}{frequency}/{order}") + else: + chunk_tracking_log.append("?0/0") + + if chunk_tracking_log: + logger.info(f"Final chunks S+F/O: {' '.join(chunk_tracking_log)}") + + result = kg_context_template.format( + entities_str=entities_str, + relations_str=relations_str, + text_chunks_str=text_units_str, + reference_list_str=reference_list_str, + ) + + # Always return both context and complete data structure (unified approach) + logger.debug( + f"[_build_context_str] Converting to user format: {len(entities_context)} entities, {len(relations_context)} relations, {len(truncated_chunks)} chunks" + ) + final_data = convert_to_user_format( + entities_context, + relations_context, + truncated_chunks, + reference_list, + query_param.mode, + entity_id_to_original, + relation_id_to_original, + ) + logger.debug( + f"[_build_context_str] Final data after conversion: {len(final_data.get('entities', []))} entities, {len(final_data.get('relationships', []))} relationships, {len(final_data.get('chunks', []))} chunks" + ) + return result, final_data + + +# Now let's update the old _build_query_context to use the new architecture +async def _build_query_context( + query: str, + ll_keywords: str, + hl_keywords: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + relationships_vdb: BaseVectorStorage, + text_chunks_db: BaseKVStorage, + query_param: QueryParam, + chunks_vdb: BaseVectorStorage = None, +) -> QueryContextResult | None: + """ + Main query context building function using the new 4-stage architecture: + 1. Search -> 2. Truncate -> 3. Merge chunks -> 4. Build LLM context + + Returns unified QueryContextResult containing both context and raw_data. + """ + + if not query: + logger.warning("Query is empty, skipping context building") + return None + + # Stage 1: Pure search + search_result = await _perform_kg_search( + query, + ll_keywords, + hl_keywords, + knowledge_graph_inst, + entities_vdb, + relationships_vdb, + text_chunks_db, + query_param, + chunks_vdb, + ) + + if not search_result["final_entities"] and not search_result["final_relations"]: + if query_param.mode != "mix": + return None + else: + if not search_result["chunk_tracking"]: + return None + + # Stage 2: Apply token truncation for LLM efficiency + truncation_result = await _apply_token_truncation( + search_result, + query_param, + text_chunks_db.global_config, + ) + + # Stage 3: Merge chunks using filtered entities/relations + merged_chunks = await _merge_all_chunks( + filtered_entities=truncation_result["filtered_entities"], + filtered_relations=truncation_result["filtered_relations"], + vector_chunks=search_result["vector_chunks"], + query=query, + knowledge_graph_inst=knowledge_graph_inst, + text_chunks_db=text_chunks_db, + query_param=query_param, + chunks_vdb=chunks_vdb, + chunk_tracking=search_result["chunk_tracking"], + query_embedding=search_result["query_embedding"], + ) + + if ( + not merged_chunks + and not truncation_result["entities_context"] + and not truncation_result["relations_context"] + ): + return None + + # Stage 4: Build final LLM context with dynamic token processing + # _build_context_str now always returns tuple[str, dict] + context, raw_data = await _build_context_str( + entities_context=truncation_result["entities_context"], + relations_context=truncation_result["relations_context"], + merged_chunks=merged_chunks, + query=query, + query_param=query_param, + global_config=text_chunks_db.global_config, + chunk_tracking=search_result["chunk_tracking"], + entity_id_to_original=truncation_result["entity_id_to_original"], + relation_id_to_original=truncation_result["relation_id_to_original"], + ) + + # Convert keywords strings to lists and add complete metadata to raw_data + hl_keywords_list = hl_keywords.split(", ") if hl_keywords else [] + ll_keywords_list = ll_keywords.split(", ") if ll_keywords else [] + + # Add complete metadata to raw_data (preserve existing metadata including query_mode) + if "metadata" not in raw_data: + raw_data["metadata"] = {} + + # Update keywords while preserving existing metadata + raw_data["metadata"]["keywords"] = { + "high_level": hl_keywords_list, + "low_level": ll_keywords_list, + } + raw_data["metadata"]["processing_info"] = { + "total_entities_found": len(search_result.get("final_entities", [])), + "total_relations_found": len(search_result.get("final_relations", [])), + "entities_after_truncation": len( + truncation_result.get("filtered_entities", []) + ), + "relations_after_truncation": len( + truncation_result.get("filtered_relations", []) + ), + "merged_chunks_count": len(merged_chunks), + "final_chunks_count": len(raw_data.get("data", {}).get("chunks", [])), + } + + logger.debug( + f"[_build_query_context] Context length: {len(context) if context else 0}" + ) + logger.debug( + f"[_build_query_context] Raw data entities: {len(raw_data.get('data', {}).get('entities', []))}, relationships: {len(raw_data.get('data', {}).get('relationships', []))}, chunks: {len(raw_data.get('data', {}).get('chunks', []))}" + ) + + return QueryContextResult(context=context, raw_data=raw_data) + + +async def _get_node_data( + query: str, + knowledge_graph_inst: BaseGraphStorage, + entities_vdb: BaseVectorStorage, + query_param: QueryParam, +): + # get similar entities + logger.info( + f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" + ) + + results = await entities_vdb.query(query, top_k=query_param.top_k) + + if not len(results): + return [], [] + + # Extract all entity IDs from your results list + node_ids = [r["entity_name"] for r in results] + + # Call the batch node retrieval and degree functions concurrently. + nodes_dict, degrees_dict = await asyncio.gather( + knowledge_graph_inst.get_nodes_batch(node_ids), + knowledge_graph_inst.node_degrees_batch(node_ids), + ) + + # Now, if you need the node data and degree in order: + node_datas = [nodes_dict.get(nid) for nid in node_ids] + node_degrees = [degrees_dict.get(nid, 0) for nid in node_ids] + + if not all([n is not None for n in node_datas]): + logger.warning("Some nodes are missing, maybe the storage is damaged") + + node_datas = [ + { + **n, + "entity_name": k["entity_name"], + "rank": d, + "created_at": k.get("created_at"), + } + for k, n, d in zip(results, node_datas, node_degrees) + if n is not None + ] + + use_relations = await _find_most_related_edges_from_entities( + node_datas, + query_param, + knowledge_graph_inst, + ) + + logger.info( + f"Local query: {len(node_datas)} entites, {len(use_relations)} relations" + ) + + # Entities are sorted by cosine similarity + # Relations are sorted by rank + weight + return node_datas, use_relations + + +async def _find_most_related_edges_from_entities( + node_datas: list[dict], + query_param: QueryParam, + knowledge_graph_inst: BaseGraphStorage, +): + node_names = [dp["entity_name"] for dp in node_datas] + batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names) + + all_edges = [] + seen = set() + + for node_name in node_names: + this_edges = batch_edges_dict.get(node_name, []) + for e in this_edges: + sorted_edge = tuple(sorted(e)) + if sorted_edge not in seen: + seen.add(sorted_edge) + all_edges.append(sorted_edge) + + # Prepare edge pairs in two forms: + # For the batch edge properties function, use dicts. + edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges] + # For edge degrees, use tuples. + edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples + + # Call the batched functions concurrently. + edge_data_dict, edge_degrees_dict = await asyncio.gather( + knowledge_graph_inst.get_edges_batch(edge_pairs_dicts), + knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples), + ) + + # Reconstruct edge_datas list in the same order as the deduplicated results. + all_edges_data = [] + for pair in all_edges: + edge_props = edge_data_dict.get(pair) + if edge_props is not None: + if "weight" not in edge_props: + logger.warning( + f"Edge {pair} missing 'weight' attribute, using default value 1.0" + ) + edge_props["weight"] = 1.0 + + combined = { + "src_tgt": pair, + "rank": edge_degrees_dict.get(pair, 0), + **edge_props, + } + all_edges_data.append(combined) + + all_edges_data = sorted( + all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True + ) + + return all_edges_data + + +async def _find_related_text_unit_from_entities( + node_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage, + knowledge_graph_inst: BaseGraphStorage, + query: str = None, + chunks_vdb: BaseVectorStorage = None, + chunk_tracking: dict = None, + query_embedding=None, +): + """ + Find text chunks related to entities using configurable chunk selection method. + + This function supports two chunk selection strategies: + 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count + 2. VECTOR: Vector similarity-based selection using embedding cosine similarity + """ + logger.debug(f"Finding text chunks from {len(node_datas)} entities") + + if not node_datas: + return [] + + # Step 1: Collect all text chunks for each entity + entities_with_chunks = [] + for entity in node_datas: + if entity.get("source_id"): + chunks = split_string_by_multi_markers( + entity["source_id"], [GRAPH_FIELD_SEP] + ) + if chunks: + entities_with_chunks.append( + { + "entity_name": entity["entity_name"], + "chunks": chunks, + "entity_data": entity, + } + ) + + if not entities_with_chunks: + logger.warning("No entities with text chunks found") + return [] + + kg_chunk_pick_method = text_chunks_db.global_config.get( + "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD + ) + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) + + # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned entities) + chunk_occurrence_count = {} + for entity_info in entities_with_chunks: + deduplicated_chunks = [] + for chunk_id in entity_info["chunks"]: + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) + + # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) + if chunk_occurrence_count[chunk_id] == 1: + deduplicated_chunks.append(chunk_id) + # count > 1 means this chunk appeared in an earlier entity, so skip it + + # Update entity's chunks to deduplicated chunks + entity_info["chunks"] = deduplicated_chunks + + # Step 3: Sort chunks for each entity by occurrence count (higher count = higher priority) + total_entity_chunks = 0 + for entity_info in entities_with_chunks: + sorted_chunks = sorted( + entity_info["chunks"], + key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), + reverse=True, + ) + entity_info["sorted_chunks"] = sorted_chunks + total_entity_chunks += len(sorted_chunks) + + selected_chunk_ids = [] # Initialize to avoid UnboundLocalError + + # Step 4: Apply the selected chunk selection algorithm + # Pick by vector similarity: + # The order of text chunks aligns with the naive retrieval's destination. + # When reranking is disabled, the text chunks delivered to the LLM tend to favor naive retrieval. + if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: + num_of_chunks = int(max_related_chunks * len(entities_with_chunks) / 2) + + # Get embedding function from global config + actual_embedding_func = text_chunks_db.embedding_func + if not actual_embedding_func: + logger.warning("No embedding function found, falling back to WEIGHT method") + kg_chunk_pick_method = "WEIGHT" + else: + try: + selected_chunk_ids = await pick_by_vector_similarity( + query=query, + text_chunks_storage=text_chunks_db, + chunks_vdb=chunks_vdb, + num_of_chunks=num_of_chunks, + entity_info=entities_with_chunks, + embedding_func=actual_embedding_func, + query_embedding=query_embedding, + ) + + if selected_chunk_ids == []: + kg_chunk_pick_method = "WEIGHT" + logger.warning( + "No entity-related chunks selected by vector similarity, falling back to WEIGHT method" + ) + else: + logger.info( + f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by vector similarity" + ) + + except Exception as e: + logger.error( + f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" + + if kg_chunk_pick_method == "WEIGHT": + # Pick by entity and chunk weight: + # When reranking is disabled, delivered more solely KG related chunks to the LLM + selected_chunk_ids = pick_by_weighted_polling( + entities_with_chunks, max_related_chunks, min_related_chunks=1 + ) + + logger.info( + f"Selecting {len(selected_chunk_ids)} from {total_entity_chunks} entity-related chunks by weighted polling" + ) + + if not selected_chunk_ids: + return [] + + # Step 5: Batch retrieve chunk data + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order + chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) + + # Step 6: Build result chunks with valid data and update chunk tracking + result_chunks = [] + for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)): + if chunk_data is not None and "content" in chunk_data: + chunk_data_copy = chunk_data.copy() + chunk_data_copy["source_type"] = "entity" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication + result_chunks.append(chunk_data_copy) + + # Update chunk tracking if provided + if chunk_tracking is not None: + chunk_tracking[chunk_id] = { + "source": "E", + "frequency": chunk_occurrence_count.get(chunk_id, 1), + "order": i + 1, # 1-based order in final entity-related results + } + + return result_chunks + + +async def _get_edge_data( + keywords, + knowledge_graph_inst: BaseGraphStorage, + relationships_vdb: BaseVectorStorage, + query_param: QueryParam, +): + logger.info( + f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" + ) + + results = await relationships_vdb.query(keywords, top_k=query_param.top_k) + + if not len(results): + return [], [] + + # Prepare edge pairs in two forms: + # For the batch edge properties function, use dicts. + edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results] + edge_data_dict = await knowledge_graph_inst.get_edges_batch(edge_pairs_dicts) + + # Reconstruct edge_datas list in the same order as results. + edge_datas = [] + for k in results: + pair = (k["src_id"], k["tgt_id"]) + edge_props = edge_data_dict.get(pair) + if edge_props is not None: + if "weight" not in edge_props: + logger.warning( + f"Edge {pair} missing 'weight' attribute, using default value 1.0" + ) + edge_props["weight"] = 1.0 + + # Keep edge data without rank, maintain vector search order + combined = { + "src_id": k["src_id"], + "tgt_id": k["tgt_id"], + "created_at": k.get("created_at", None), + **edge_props, + } + edge_datas.append(combined) + + # Relations maintain vector search order (sorted by similarity) + + use_entities = await _find_most_related_entities_from_relationships( + edge_datas, + query_param, + knowledge_graph_inst, + ) + + logger.info( + f"Global query: {len(use_entities)} entites, {len(edge_datas)} relations" + ) + + return edge_datas, use_entities + + +async def _find_most_related_entities_from_relationships( + edge_datas: list[dict], + query_param: QueryParam, + knowledge_graph_inst: BaseGraphStorage, +): + entity_names = [] + seen = set() + + for e in edge_datas: + if e["src_id"] not in seen: + entity_names.append(e["src_id"]) + seen.add(e["src_id"]) + if e["tgt_id"] not in seen: + entity_names.append(e["tgt_id"]) + seen.add(e["tgt_id"]) + + # Only get nodes data, no need for node degrees + nodes_dict = await knowledge_graph_inst.get_nodes_batch(entity_names) + + # Rebuild the list in the same order as entity_names + node_datas = [] + for entity_name in entity_names: + node = nodes_dict.get(entity_name) + if node is None: + logger.warning(f"Node '{entity_name}' not found in batch retrieval.") + continue + # Combine the node data with the entity name, no rank needed + combined = {**node, "entity_name": entity_name} + node_datas.append(combined) + + return node_datas + + +async def _find_related_text_unit_from_relations( + edge_datas: list[dict], + query_param: QueryParam, + text_chunks_db: BaseKVStorage, + entity_chunks: list[dict] = None, + query: str = None, + chunks_vdb: BaseVectorStorage = None, + chunk_tracking: dict = None, + query_embedding=None, +): + """ + Find text chunks related to relationships using configurable chunk selection method. + + This function supports two chunk selection strategies: + 1. WEIGHT: Linear gradient weighted polling based on chunk occurrence count + 2. VECTOR: Vector similarity-based selection using embedding cosine similarity + """ + logger.debug(f"Finding text chunks from {len(edge_datas)} relations") + + if not edge_datas: + return [] + + # Step 1: Collect all text chunks for each relationship + relations_with_chunks = [] + for relation in edge_datas: + if relation.get("source_id"): + chunks = split_string_by_multi_markers( + relation["source_id"], [GRAPH_FIELD_SEP] + ) + if chunks: + # Build relation identifier + if "src_tgt" in relation: + rel_key = tuple(sorted(relation["src_tgt"])) + else: + rel_key = tuple( + sorted([relation.get("src_id"), relation.get("tgt_id")]) + ) + + relations_with_chunks.append( + { + "relation_key": rel_key, + "chunks": chunks, + "relation_data": relation, + } + ) + + if not relations_with_chunks: + logger.warning("No relation-related chunks found") + return [] + + kg_chunk_pick_method = text_chunks_db.global_config.get( + "kg_chunk_pick_method", DEFAULT_KG_CHUNK_PICK_METHOD + ) + max_related_chunks = text_chunks_db.global_config.get( + "related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER + ) + + # Step 2: Count chunk occurrences and deduplicate (keep chunks from earlier positioned relationships) + # Also remove duplicates with entity_chunks + + # Extract chunk IDs from entity_chunks for deduplication + entity_chunk_ids = set() + if entity_chunks: + for chunk in entity_chunks: + chunk_id = chunk.get("chunk_id") + if chunk_id: + entity_chunk_ids.add(chunk_id) + + chunk_occurrence_count = {} + # Track unique chunk_ids that have been removed to avoid double counting + removed_entity_chunk_ids = set() + + for relation_info in relations_with_chunks: + deduplicated_chunks = [] + for chunk_id in relation_info["chunks"]: + # Skip chunks that already exist in entity_chunks + if chunk_id in entity_chunk_ids: + # Only count each unique chunk_id once + removed_entity_chunk_ids.add(chunk_id) + continue + + chunk_occurrence_count[chunk_id] = ( + chunk_occurrence_count.get(chunk_id, 0) + 1 + ) + + # If this is the first occurrence (count == 1), keep it; otherwise skip (duplicate from later position) + if chunk_occurrence_count[chunk_id] == 1: + deduplicated_chunks.append(chunk_id) + # count > 1 means this chunk appeared in an earlier relationship, so skip it + + # Update relationship's chunks to deduplicated chunks + relation_info["chunks"] = deduplicated_chunks + + # Check if any relations still have chunks after deduplication + relations_with_chunks = [ + relation_info + for relation_info in relations_with_chunks + if relation_info["chunks"] + ] + + if not relations_with_chunks: + logger.info( + f"Find no additional relations-related chunks from {len(edge_datas)} relations" + ) + return [] + + # Step 3: Sort chunks for each relationship by occurrence count (higher count = higher priority) + total_relation_chunks = 0 + for relation_info in relations_with_chunks: + sorted_chunks = sorted( + relation_info["chunks"], + key=lambda chunk_id: chunk_occurrence_count.get(chunk_id, 0), + reverse=True, + ) + relation_info["sorted_chunks"] = sorted_chunks + total_relation_chunks += len(sorted_chunks) + + logger.info( + f"Find {total_relation_chunks} additional chunks in {len(relations_with_chunks)} relations (deduplicated {len(removed_entity_chunk_ids)})" + ) + + # Step 4: Apply the selected chunk selection algorithm + selected_chunk_ids = [] # Initialize to avoid UnboundLocalError + + if kg_chunk_pick_method == "VECTOR" and query and chunks_vdb: + num_of_chunks = int(max_related_chunks * len(relations_with_chunks) / 2) + + # Get embedding function from global config + actual_embedding_func = text_chunks_db.embedding_func + if not actual_embedding_func: + logger.warning("No embedding function found, falling back to WEIGHT method") + kg_chunk_pick_method = "WEIGHT" + else: + try: + selected_chunk_ids = await pick_by_vector_similarity( + query=query, + text_chunks_storage=text_chunks_db, + chunks_vdb=chunks_vdb, + num_of_chunks=num_of_chunks, + entity_info=relations_with_chunks, + embedding_func=actual_embedding_func, + query_embedding=query_embedding, + ) + + if selected_chunk_ids == []: + kg_chunk_pick_method = "WEIGHT" + logger.warning( + "No relation-related chunks selected by vector similarity, falling back to WEIGHT method" + ) + else: + logger.info( + f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by vector similarity" + ) + + except Exception as e: + logger.error( + f"Error in vector similarity sorting: {e}, falling back to WEIGHT method" + ) + kg_chunk_pick_method = "WEIGHT" + + if kg_chunk_pick_method == "WEIGHT": + # Apply linear gradient weighted polling algorithm + selected_chunk_ids = pick_by_weighted_polling( + relations_with_chunks, max_related_chunks, min_related_chunks=1 + ) + + logger.info( + f"Selecting {len(selected_chunk_ids)} from {total_relation_chunks} relation-related chunks by weighted polling" + ) + + logger.debug( + f"KG related chunks: {len(entity_chunks)} from entitys, {len(selected_chunk_ids)} from relations" + ) + + if not selected_chunk_ids: + return [] + + # Step 5: Batch retrieve chunk data + unique_chunk_ids = list( + dict.fromkeys(selected_chunk_ids) + ) # Remove duplicates while preserving order + chunk_data_list = await text_chunks_db.get_by_ids(unique_chunk_ids) + + # Step 6: Build result chunks with valid data and update chunk tracking + result_chunks = [] + for i, (chunk_id, chunk_data) in enumerate(zip(unique_chunk_ids, chunk_data_list)): + if chunk_data is not None and "content" in chunk_data: + chunk_data_copy = chunk_data.copy() + chunk_data_copy["source_type"] = "relationship" + chunk_data_copy["chunk_id"] = chunk_id # Add chunk_id for deduplication + result_chunks.append(chunk_data_copy) + + # Update chunk tracking if provided + if chunk_tracking is not None: + chunk_tracking[chunk_id] = { + "source": "R", + "frequency": chunk_occurrence_count.get(chunk_id, 1), + "order": i + 1, # 1-based order in final relation-related results + } + + return result_chunks + + +@overload +async def naive_query( + query: str, + chunks_vdb: BaseVectorStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + return_raw_data: Literal[True] = True, +) -> dict[str, Any]: ... + + +@overload +async def naive_query( + query: str, + chunks_vdb: BaseVectorStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, + return_raw_data: Literal[False] = False, +) -> str | AsyncIterator[str]: ... + + +async def naive_query( + query: str, + chunks_vdb: BaseVectorStorage, + query_param: QueryParam, + global_config: dict[str, str], + hashing_kv: BaseKVStorage | None = None, + system_prompt: str | None = None, +) -> QueryResult | None: + """ + Execute naive query and return unified QueryResult object. + + Args: + query: Query string + chunks_vdb: Document chunks vector database + query_param: Query parameters + global_config: Global configuration + hashing_kv: Cache storage + system_prompt: System prompt + + Returns: + QueryResult | None: Unified query result object containing: + - content: Non-streaming response text content + - response_iterator: Streaming response iterator + - raw_data: Complete structured data (including references and metadata) + - is_streaming: Whether this is a streaming result + + Returns None when no relevant chunks are retrieved. + """ + + if not query: + return QueryResult(content=PROMPTS["fail_response"]) + + if query_param.model_func: + use_model_func = query_param.model_func + else: + use_model_func = global_config["llm_model_func"] + # Apply higher priority (5) to query relation LLM function + use_model_func = partial(use_model_func, _priority=5) + + tokenizer: Tokenizer = global_config["tokenizer"] + if not tokenizer: + logger.error("Tokenizer not found in global configuration.") + return QueryResult(content=PROMPTS["fail_response"]) + + chunks = await _get_vector_context(query, chunks_vdb, query_param, None) + + if chunks is None or len(chunks) == 0: + logger.info( + "[naive_query] No relevant document chunks found; returning no-result." + ) + return None + + # Calculate dynamic token limit for chunks + max_total_tokens = getattr( + query_param, + "max_total_tokens", + global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS), + ) + + # Calculate system prompt template tokens (excluding content_data) + user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" + response_type = ( + query_param.response_type + if query_param.response_type + else "Multiple Paragraphs" + ) + + # Use the provided system prompt or default + sys_prompt_template = ( + system_prompt if system_prompt else PROMPTS["naive_rag_response"] + ) + + # Create a preliminary system prompt with empty content_data to calculate overhead + pre_sys_prompt = sys_prompt_template.format( + response_type=response_type, + user_prompt=user_prompt, + content_data="", # Empty for overhead calculation + ) + + # Calculate available tokens for chunks + sys_prompt_tokens = len(tokenizer.encode(pre_sys_prompt)) + query_tokens = len(tokenizer.encode(query)) + buffer_tokens = 200 # reserved for reference list and safety buffer + available_chunk_tokens = max_total_tokens - ( + sys_prompt_tokens + query_tokens + buffer_tokens + ) + + logger.debug( + f"Naive query token allocation - Total: {max_total_tokens}, SysPrompt: {sys_prompt_tokens}, Query: {query_tokens}, Buffer: {buffer_tokens}, Available for chunks: {available_chunk_tokens}" + ) + + # Process chunks using unified processing with dynamic token limit + processed_chunks = await process_chunks_unified( + query=query, + unique_chunks=chunks, + query_param=query_param, + global_config=global_config, + source_type="vector", + chunk_token_limit=available_chunk_tokens, # Pass dynamic limit + ) + + # Generate reference list from processed chunks using the new common function + reference_list, processed_chunks_with_ref_ids = generate_reference_list_from_chunks( + processed_chunks + ) + + logger.info(f"Final context: {len(processed_chunks_with_ref_ids)} chunks") + + # Build raw data structure for naive mode using processed chunks with reference IDs + raw_data = convert_to_user_format( + [], # naive mode has no entities + [], # naive mode has no relationships + processed_chunks_with_ref_ids, + reference_list, + "naive", + ) + + # Add complete metadata for naive mode + if "metadata" not in raw_data: + raw_data["metadata"] = {} + raw_data["metadata"]["keywords"] = { + "high_level": [], # naive mode has no keyword extraction + "low_level": [], # naive mode has no keyword extraction + } + raw_data["metadata"]["processing_info"] = { + "total_chunks_found": len(chunks), + "final_chunks_count": len(processed_chunks_with_ref_ids), + } + + # Build chunks_context from processed chunks with reference IDs + chunks_context = [] + for i, chunk in enumerate(processed_chunks_with_ref_ids): + chunks_context.append( + { + "reference_id": chunk["reference_id"], + "content": chunk["content"], + } + ) + + text_units_str = "\n".join( + json.dumps(text_unit, ensure_ascii=False) for text_unit in chunks_context + ) + reference_list_str = "\n".join( + f"[{ref['reference_id']}] {ref['file_path']}" + for ref in reference_list + if ref["reference_id"] + ) + + naive_context_template = PROMPTS["naive_query_context"] + context_content = naive_context_template.format( + text_chunks_str=text_units_str, + reference_list_str=reference_list_str, + ) + + if query_param.only_need_context and not query_param.only_need_prompt: + return QueryResult(content=context_content, raw_data=raw_data) + + sys_prompt = sys_prompt_template.format( + response_type=query_param.response_type, + user_prompt=user_prompt, + content_data=context_content, + ) + + user_query = query + + if query_param.only_need_prompt: + prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) + return QueryResult(content=prompt_content, raw_data=raw_data) + + # Handle cache + args_hash = compute_args_hash( + query_param.mode, + query, + query_param.response_type, + query_param.top_k, + query_param.chunk_top_k, + query_param.max_entity_tokens, + query_param.max_relation_tokens, + query_param.max_total_tokens, + query_param.user_prompt or "", + query_param.enable_rerank, + ) + cached_result = await handle_cache( + hashing_kv, args_hash, user_query, query_param.mode, cache_type="query" + ) + if cached_result is not None: + cached_response, _ = cached_result # Extract content, ignore timestamp + logger.info( + " == LLM cache == Query cache hit, using cached response as query result" + ) + response = cached_response + else: + response = await use_model_func( + user_query, + system_prompt=sys_prompt, + history_messages=query_param.conversation_history, + enable_cot=True, + stream=query_param.stream, + ) + + if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): + queryparam_dict = { + "mode": query_param.mode, + "response_type": query_param.response_type, + "top_k": query_param.top_k, + "chunk_top_k": query_param.chunk_top_k, + "max_entity_tokens": query_param.max_entity_tokens, + "max_relation_tokens": query_param.max_relation_tokens, + "max_total_tokens": query_param.max_total_tokens, + "user_prompt": query_param.user_prompt or "", + "enable_rerank": query_param.enable_rerank, + } + await save_to_cache( + hashing_kv, + CacheData( + args_hash=args_hash, + content=response, + prompt=query, + mode=query_param.mode, + cache_type="query", + queryparam=queryparam_dict, + ), + ) + + # Return unified result based on actual response type + if isinstance(response, str): + # Non-streaming response (string) + if len(response) > len(sys_prompt): + response = ( + response[len(sys_prompt) :] + .replace(sys_prompt, "") + .replace("user", "") + .replace("model", "") + .replace(query, "") + .replace("", "") + .replace("", "") + .strip() + ) + + return QueryResult(content=response, raw_data=raw_data) + else: + # Streaming response (AsyncIterator) + return QueryResult( + response_iterator=response, raw_data=raw_data, is_streaming=True + ) diff --git a/lightrag/utils-Backup01.py b/lightrag/utils-Backup01.py new file mode 100644 index 0000000000..e5416d09da --- /dev/null +++ b/lightrag/utils-Backup01.py @@ -0,0 +1,3375 @@ +from __future__ import annotations +import weakref + +import sys + +import asyncio +import html +import csv +import inspect +import json +import logging +import logging.handlers +import os +import re +import time +import uuid +from dataclasses import dataclass +from datetime import datetime +from functools import wraps +from hashlib import md5 +from typing import ( + Any, + Protocol, + Callable, + TYPE_CHECKING, + List, + Optional, + Iterable, + Sequence, + Collection, +) +import numpy as np +from dotenv import load_dotenv + +from lightrag.constants import ( + DEFAULT_LOG_MAX_BYTES, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_FILENAME, + GRAPH_FIELD_SEP, + DEFAULT_MAX_TOTAL_TOKENS, + DEFAULT_SOURCE_IDS_LIMIT_METHOD, + VALID_SOURCE_IDS_LIMIT_METHODS, + SOURCE_IDS_LIMIT_METHOD_FIFO, +) + +# Precompile regex pattern for JSON sanitization (module-level, compiled once) +_SURROGATE_PATTERN = re.compile(r"[\uD800-\uDFFF\uFFFE\uFFFF]") + + +class SafeStreamHandler(logging.StreamHandler): + """StreamHandler that gracefully handles closed streams during shutdown. + + This handler prevents "ValueError: I/O operation on closed file" errors + that can occur when pytest or other test frameworks close stdout/stderr + before Python's logging cleanup runs. + """ + + def flush(self): + """Flush the stream, ignoring errors if the stream is closed.""" + try: + super().flush() + except (ValueError, OSError): + # Stream is closed or otherwise unavailable, silently ignore + pass + + def close(self): + """Close the handler, ignoring errors if the stream is already closed.""" + try: + super().close() + except (ValueError, OSError): + # Stream is closed or otherwise unavailable, silently ignore + pass + + +# Initialize logger with basic configuration +logger = logging.getLogger("lightrag") +logger.propagate = False # prevent log message send to root logger +logger.setLevel(logging.INFO) + +# Add console handler if no handlers exist +if not logger.handlers: + console_handler = SafeStreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(levelname)s: %(message)s") + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + +# Set httpx logging level to WARNING +logging.getLogger("httpx").setLevel(logging.WARNING) + + +def _patch_ascii_colors_console_handler() -> None: + """Prevent ascii_colors from printing flush errors during interpreter exit.""" + + try: + from ascii_colors import ConsoleHandler + except ImportError: + return + + if getattr(ConsoleHandler, "_lightrag_patched", False): + return + + original_handle_error = ConsoleHandler.handle_error + + def _safe_handle_error(self, message: str) -> None: # type: ignore[override] + exc_type, _, _ = sys.exc_info() + if exc_type in (ValueError, OSError) and "close" in message.lower(): + return + original_handle_error(self, message) + + ConsoleHandler.handle_error = _safe_handle_error # type: ignore[assignment] + ConsoleHandler._lightrag_patched = True # type: ignore[attr-defined] + + +_patch_ascii_colors_console_handler() + + +# Global import for pypinyin with startup-time logging +try: + import pypinyin + + _PYPINYIN_AVAILABLE = True + # logger.info("pypinyin loaded successfully for Chinese pinyin sorting") +except ImportError: + pypinyin = None + _PYPINYIN_AVAILABLE = False + logger.warning( + "pypinyin is not installed. Chinese pinyin sorting will use simple string sorting." + ) + + +async def safe_vdb_operation_with_exception( + operation: Callable, + operation_name: str, + entity_name: str = "", + max_retries: int = 3, + retry_delay: float = 0.2, + logger_func: Optional[Callable] = None, +) -> None: + """ + Safely execute vector database operations with retry mechanism and exception handling. + + This function ensures that VDB operations are executed with proper error handling + and retry logic. If all retries fail, it raises an exception to maintain data consistency. + + Args: + operation: The async operation to execute + operation_name: Operation name for logging purposes + entity_name: Entity name for logging purposes + max_retries: Maximum number of retry attempts + retry_delay: Delay between retries in seconds + logger_func: Logger function to use for error messages + + Raises: + Exception: When operation fails after all retry attempts + """ + log_func = logger_func or logger.warning + + for attempt in range(max_retries): + try: + await operation() + return # Success, return immediately + except Exception as e: + if attempt >= max_retries - 1: + error_msg = f"VDB {operation_name} failed for {entity_name} after {max_retries} attempts: {e}" + log_func(error_msg) + raise Exception(error_msg) from e + else: + log_func( + f"VDB {operation_name} attempt {attempt + 1} failed for {entity_name}: {e}, retrying..." + ) + if retry_delay > 0: + await asyncio.sleep(retry_delay) + + +def get_env_value( + env_key: str, default: any, value_type: type = str, special_none: bool = False +) -> any: + """ + Get value from environment variable with type conversion + + Args: + env_key (str): Environment variable key + default (any): Default value if env variable is not set + value_type (type): Type to convert the value to + special_none (bool): If True, return None when value is "None" + + Returns: + any: Converted value from environment or default + """ + value = os.getenv(env_key) + if value is None: + return default + + # Handle special case for "None" string + if special_none and value == "None": + return None + + if value_type is bool: + return value.lower() in ("true", "1", "yes", "t", "on") + + # Handle list type with JSON parsing + if value_type is list: + try: + import json + + parsed_value = json.loads(value) + # Ensure the parsed value is actually a list + if isinstance(parsed_value, list): + return parsed_value + else: + logger.warning( + f"Environment variable {env_key} is not a valid JSON list, using default" + ) + return default + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + f"Failed to parse {env_key} as JSON list: {e}, using default" + ) + return default + + try: + return value_type(value) + except (ValueError, TypeError): + return default + + +# Use TYPE_CHECKING to avoid circular imports +if TYPE_CHECKING: + from lightrag.base import BaseKVStorage, BaseVectorStorage, QueryParam + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + +VERBOSE_DEBUG = os.getenv("VERBOSE", "false").lower() == "true" + + +def verbose_debug(msg: str, *args, **kwargs): + """Function for outputting detailed debug information. + When VERBOSE_DEBUG=True, outputs the complete message. + When VERBOSE_DEBUG=False, outputs only the first 50 characters. + + Args: + msg: The message format string + *args: Arguments to be formatted into the message + **kwargs: Keyword arguments passed to logger.debug() + """ + if VERBOSE_DEBUG: + logger.debug(msg, *args, **kwargs) + else: + # Format the message with args first + if args: + formatted_msg = msg % args + else: + formatted_msg = msg + # Then truncate the formatted message + truncated_msg = ( + formatted_msg[:150] + "..." if len(formatted_msg) > 150 else formatted_msg + ) + # Remove consecutive newlines + truncated_msg = re.sub(r"\n+", "\n", truncated_msg) + logger.debug(truncated_msg, **kwargs) + + +def set_verbose_debug(enabled: bool): + """Enable or disable verbose debug output""" + global VERBOSE_DEBUG + VERBOSE_DEBUG = enabled + + +statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0} + + +class LightragPathFilter(logging.Filter): + """Filter for lightrag logger to filter out frequent path access logs""" + + def __init__(self): + super().__init__() + # Define paths to be filtered + self.filtered_paths = [ + "/documents", + "/documents/paginated", + "/health", + "/webui/", + "/documents/pipeline_status", + ] + # self.filtered_paths = ["/health", "/webui/"] + + def filter(self, record): + try: + # Check if record has the required attributes for an access log + if not hasattr(record, "args") or not isinstance(record.args, tuple): + return True + if len(record.args) < 5: + return True + + # Extract method, path and status from the record args + method = record.args[1] + path = record.args[2] + status = record.args[4] + + # Filter out successful GET/POST requests to filtered paths + if ( + (method == "GET" or method == "POST") + and (status == 200 or status == 304) + and path in self.filtered_paths + ): + return False + + return True + except Exception: + # In case of any error, let the message through + return True + + +def setup_logger( + logger_name: str, + level: str = "INFO", + add_filter: bool = False, + log_file_path: str | None = None, + enable_file_logging: bool = True, +): + """Set up a logger with console and optionally file handlers + + Args: + logger_name: Name of the logger to set up + level: Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + add_filter: Whether to add LightragPathFilter to the logger + log_file_path: Path to the log file. If None and file logging is enabled, defaults to lightrag.log in LOG_DIR or cwd + enable_file_logging: Whether to enable logging to a file (defaults to True) + """ + # Configure formatters + detailed_formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + simple_formatter = logging.Formatter("%(levelname)s: %(message)s") + + logger_instance = logging.getLogger(logger_name) + logger_instance.setLevel(level) + logger_instance.handlers = [] # Clear existing handlers + logger_instance.propagate = False + + # Add console handler with safe stream handling + console_handler = SafeStreamHandler() + console_handler.setFormatter(simple_formatter) + console_handler.setLevel(level) + logger_instance.addHandler(console_handler) + + # Add file handler by default unless explicitly disabled + if enable_file_logging: + # Get log file path + if log_file_path is None: + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath(os.path.join(log_dir, DEFAULT_LOG_FILENAME)) + + # Ensure log directory exists + os.makedirs(os.path.dirname(log_file_path), exist_ok=True) + + # Get log file max size and backup count from environment variables + log_max_bytes = get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int) + log_backup_count = get_env_value( + "LOG_BACKUP_COUNT", DEFAULT_LOG_BACKUP_COUNT, int + ) + + try: + # Add file handler + file_handler = logging.handlers.RotatingFileHandler( + filename=log_file_path, + maxBytes=log_max_bytes, + backupCount=log_backup_count, + encoding="utf-8", + ) + file_handler.setFormatter(detailed_formatter) + file_handler.setLevel(level) + logger_instance.addHandler(file_handler) + except PermissionError as e: + logger.warning(f"Could not create log file at {log_file_path}: {str(e)}") + logger.warning("Continuing with console logging only") + + # Add path filter if requested + if add_filter: + path_filter = LightragPathFilter() + logger_instance.addFilter(path_filter) + + +class UnlimitedSemaphore: + """A context manager that allows unlimited access.""" + + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc, tb): + pass + + +@dataclass +class TaskState: + """Task state tracking for priority queue management""" + + future: asyncio.Future + start_time: float + execution_start_time: float = None + worker_started: bool = False + cancellation_requested: bool = False + cleanup_done: bool = False + + +@dataclass +class EmbeddingFunc: + """Embedding function wrapper with dimension validation + + This class wraps an embedding function to ensure that the output embeddings have the correct dimension. + If wrapped multiple times, the inner wrappers will be automatically unwrapped to prevent + configuration conflicts where inner wrapper settings would override outer wrapper settings. + + Using functools.partial for parameter binding: + A common pattern is to use functools.partial to pre-bind model and host parameters + to an embedding function. When the base embedding function is already decorated with + @wrap_embedding_func_with_attrs (e.g., ollama_embed), use `.func` to access the + original unwrapped function to avoid double wrapping: + + Example: + from functools import partial + + # ❌ Wrong - causes double wrapping (inner EmbeddingFunc still executes) + func=partial(ollama_embed, embed_model="bge-m3:latest", host="http://localhost:11434") + + # ✅ Correct - access the unwrapped function via .func + func=partial(ollama_embed.func, embed_model="bge-m3:latest", host="http://localhost:11434") + + Args: + embedding_dim: Expected dimension of the embeddings(For dimension checking and workspace data isolation in vector DB) + func: The actual embedding function to wrap + max_token_size: Enable embedding token limit checking for description summarization(Set embedding_token_limit in LightRAG) + send_dimensions: Whether to inject embedding_dim argument to underlying function + model_name: Model name for implementing workspace data isolation in vector DB + """ + + embedding_dim: int + func: callable + max_token_size: int | None = None + send_dimensions: bool = False + model_name: str | None = ( + None # Model name for implementing workspace data isolation in vector DB + ) + + def __post_init__(self): + """Unwrap nested EmbeddingFunc to prevent double wrapping issues. + + When an EmbeddingFunc wraps another EmbeddingFunc, the inner wrapper's + __call__ preprocessing would override the outer wrapper's settings. + This method detects and unwraps nested EmbeddingFunc instances to ensure + that only the outermost wrapper's configuration is applied. + """ + # Check if func is already an EmbeddingFunc instance and unwrap it + max_unwrap_depth = 3 # Safety limit to prevent infinite loops + unwrap_count = 0 + while isinstance(self.func, EmbeddingFunc): + unwrap_count += 1 + if unwrap_count > max_unwrap_depth: + raise ValueError( + f"EmbeddingFunc unwrap depth exceeded {max_unwrap_depth}. " + "Possible circular reference detected." + ) + # Unwrap to get the original function + self.func = self.func.func + + if unwrap_count > 0: + logger.warning( + f"Detected nested EmbeddingFunc wrapping (depth: {unwrap_count}), " + "auto-unwrapped to prevent configuration conflicts. " + "Consider using .func to access the unwrapped function directly." + ) + + async def __call__(self, *args, **kwargs) -> np.ndarray: + # Only inject embedding_dim when send_dimensions is True + if self.send_dimensions: + # Check if user provided embedding_dim parameter + if "embedding_dim" in kwargs: + user_provided_dim = kwargs["embedding_dim"] + # If user's value differs from class attribute, output warning + if ( + user_provided_dim is not None + and user_provided_dim != self.embedding_dim + ): + logger.warning( + f"Ignoring user-provided embedding_dim={user_provided_dim}, " + f"using declared embedding_dim={self.embedding_dim} from decorator" + ) + + # Inject embedding_dim from decorator + kwargs["embedding_dim"] = self.embedding_dim + + # Check if underlying function supports max_token_size and inject if not provided + if self.max_token_size is not None and "max_token_size" not in kwargs: + sig = inspect.signature(self.func) + if "max_token_size" in sig.parameters: + kwargs["max_token_size"] = self.max_token_size + + # Call the actual embedding function + result = await self.func(*args, **kwargs) + + # Validate embedding dimensions using total element count + total_elements = result.size # Total number of elements in the numpy array + expected_dim = self.embedding_dim + + # Check if total elements can be evenly divided by embedding_dim + if total_elements % expected_dim != 0: + raise ValueError( + f"Embedding dimension mismatch detected: " + f"total elements ({total_elements}) cannot be evenly divided by " + f"expected dimension ({expected_dim}). " + ) + + # Optional: Verify vector count matches input text count + # actual_vectors = total_elements // expected_dim + # if args and isinstance(args[0], (list, tuple)): + # expected_vectors = len(args[0]) + # if actual_vectors != expected_vectors: + # raise ValueError( + # f"Vector count mismatch: " + # f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)." + # ) + + # Optional: Verify vector count matches input text count + actual_vectors = total_elements // expected_dim + if args and isinstance(args[0], (list, tuple)): + expected_vectors = len(args[0]) + if actual_vectors != expected_vectors: + # WORKAROUND: Handle the doubling bug in multimodal/mineru pipelines + if actual_vectors == 2 * expected_vectors: + logger.warning( + f"Vector count doubling detected ({actual_vectors} vs {expected_vectors}). Slicing result to match." + ) + # Reshape if it's a 2D array, slice, then return + if len(result.shape) > 1: + return result[:expected_vectors] + else: + # Fallback for flat arrays + new_size = expected_vectors * expected_dim + return result[:new_size] + else: + raise ValueError( + f"Vector count mismatch: " + f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)." + ) + + return result + + +def compute_args_hash(*args: Any) -> str: + """Compute a hash for the given arguments with safe Unicode handling. + + Args: + *args: Arguments to hash + Returns: + str: Hash string + """ + # Convert all arguments to strings and join them + args_str = "".join([str(arg) for arg in args]) + + # Use 'replace' error handling to safely encode problematic Unicode characters + # This replaces invalid characters with Unicode replacement character (U+FFFD) + try: + return md5(args_str.encode("utf-8")).hexdigest() + except UnicodeEncodeError: + # Handle surrogate characters and other encoding issues + safe_bytes = args_str.encode("utf-8", errors="replace") + return md5(safe_bytes).hexdigest() + + +def compute_mdhash_id(content: str, prefix: str = "") -> str: + """ + Compute a unique ID for a given content string. + + The ID is a combination of the given prefix and the MD5 hash of the content string. + """ + return prefix + compute_args_hash(content) + + +def generate_cache_key(mode: str, cache_type: str, hash_value: str) -> str: + """Generate a flattened cache key in the format {mode}:{cache_type}:{hash} + + Args: + mode: Cache mode (e.g., 'default', 'local', 'global') + cache_type: Type of cache (e.g., 'extract', 'query', 'keywords') + hash_value: Hash value from compute_args_hash + + Returns: + str: Flattened cache key + """ + return f"{mode}:{cache_type}:{hash_value}" + + +def parse_cache_key(cache_key: str) -> tuple[str, str, str] | None: + """Parse a flattened cache key back into its components + + Args: + cache_key: Flattened cache key in format {mode}:{cache_type}:{hash} + + Returns: + tuple[str, str, str] | None: (mode, cache_type, hash) or None if invalid format + """ + parts = cache_key.split(":", 2) + if len(parts) == 3: + return parts[0], parts[1], parts[2] + return None + + +# Custom exception classes +class QueueFullError(Exception): + """Raised when the queue is full and the wait times out""" + + pass + + +class WorkerTimeoutError(Exception): + """Worker-level timeout exception with specific timeout information""" + + def __init__(self, timeout_value: float, timeout_type: str = "execution"): + self.timeout_value = timeout_value + self.timeout_type = timeout_type + super().__init__(f"Worker {timeout_type} timeout after {timeout_value}s") + + +class HealthCheckTimeoutError(Exception): + """Health Check-level timeout exception""" + + def __init__(self, timeout_value: float, execution_duration: float): + self.timeout_value = timeout_value + self.execution_duration = execution_duration + super().__init__( + f"Task forcefully terminated due to execution timeout (>{timeout_value}s, actual: {execution_duration:.1f}s)" + ) + + +def priority_limit_async_func_call( + max_size: int, + llm_timeout: float = None, + max_execution_timeout: float = None, + max_task_duration: float = None, + max_queue_size: int = 1000, + cleanup_timeout: float = 2.0, + queue_name: str = "limit_async", +): + """ + Enhanced priority-limited asynchronous function call decorator with robust timeout handling + + This decorator provides a comprehensive solution for managing concurrent LLM requests with: + - Multi-layer timeout protection (LLM -> Worker -> Health Check -> User) + - Task state tracking to prevent race conditions + - Enhanced health check system with stuck task detection + - Proper resource cleanup and error recovery + + Args: + max_size: Maximum number of concurrent calls + max_queue_size: Maximum queue capacity to prevent memory overflow + llm_timeout: LLM provider timeout (from global config), used to calculate other timeouts + max_execution_timeout: Maximum time for worker to execute function (defaults to llm_timeout + 30s) + max_task_duration: Maximum time before health check intervenes (defaults to llm_timeout + 60s) + cleanup_timeout: Maximum time to wait for cleanup operations (defaults to 2.0s) + queue_name: Optional queue name for logging identification (defaults to "limit_async") + + Returns: + Decorator function + """ + + def final_decro(func): + # Ensure func is callable + if not callable(func): + raise TypeError(f"Expected a callable object, got {type(func)}") + + # Calculate timeout hierarchy if llm_timeout is provided (Dynamic Timeout Calculation) + if llm_timeout is not None: + nonlocal max_execution_timeout, max_task_duration + if max_execution_timeout is None: + max_execution_timeout = ( + llm_timeout * 2 + ) # Reserved timeout buffer for low-level retry + if max_task_duration is None: + max_task_duration = ( + llm_timeout * 2 + 15 + ) # Reserved timeout buffer for health check phase + + queue = asyncio.PriorityQueue(maxsize=max_queue_size) + tasks = set() + initialization_lock = asyncio.Lock() + counter = 0 + shutdown_event = asyncio.Event() + initialized = False + worker_health_check_task = None + + # Enhanced task state management + task_states = {} # task_id -> TaskState + task_states_lock = asyncio.Lock() + active_futures = weakref.WeakSet() + reinit_count = 0 + + async def worker(): + """Enhanced worker that processes tasks with proper timeout and state management""" + try: + while not shutdown_event.is_set(): + try: + # Get task from queue with timeout for shutdown checking + try: + ( + priority, + count, + task_id, + args, + kwargs, + ) = await asyncio.wait_for(queue.get(), timeout=1.0) + except asyncio.TimeoutError: + continue + + # Get task state and mark worker as started + async with task_states_lock: + if task_id not in task_states: + queue.task_done() + continue + task_state = task_states[task_id] + task_state.worker_started = True + # Record execution start time when worker actually begins processing + task_state.execution_start_time = ( + asyncio.get_event_loop().time() + ) + + # Check if task was cancelled before worker started + if ( + task_state.cancellation_requested + or task_state.future.cancelled() + ): + async with task_states_lock: + task_states.pop(task_id, None) + queue.task_done() + continue + + try: + # Execute function with timeout protection + if max_execution_timeout is not None: + result = await asyncio.wait_for( + func(*args, **kwargs), timeout=max_execution_timeout + ) + else: + result = await func(*args, **kwargs) + + # Set result if future is still valid + if not task_state.future.done(): + task_state.future.set_result(result) + + except asyncio.TimeoutError: + # Worker-level timeout (max_execution_timeout exceeded) + logger.warning( + f"{queue_name}: Worker timeout for task {task_id} after {max_execution_timeout}s" + ) + if not task_state.future.done(): + task_state.future.set_exception( + WorkerTimeoutError( + max_execution_timeout, "execution" + ) + ) + except asyncio.CancelledError: + # Task was cancelled during execution + if not task_state.future.done(): + task_state.future.cancel() + logger.debug( + f"{queue_name}: Task {task_id} cancelled during execution" + ) + except Exception as e: + # Function execution error + logger.error( + f"{queue_name}: Error in decorated function for task {task_id}: {str(e)}" + ) + if not task_state.future.done(): + task_state.future.set_exception(e) + finally: + # Clean up task state + async with task_states_lock: + task_states.pop(task_id, None) + queue.task_done() + + except Exception as e: + # Critical error in worker loop + logger.error( + f"{queue_name}: Critical error in worker: {str(e)}" + ) + await asyncio.sleep(0.1) + finally: + logger.debug(f"{queue_name}: Worker exiting") + + async def enhanced_health_check(): + """Enhanced health check with stuck task detection and recovery""" + nonlocal initialized + try: + while not shutdown_event.is_set(): + await asyncio.sleep(5) # Check every 5 seconds + + current_time = asyncio.get_event_loop().time() + + # Detect and handle stuck tasks based on execution start time + if max_task_duration is not None: + stuck_tasks = [] + async with task_states_lock: + for task_id, task_state in list(task_states.items()): + # Only check tasks that have started execution + if ( + task_state.worker_started + and task_state.execution_start_time is not None + and current_time - task_state.execution_start_time + > max_task_duration + ): + stuck_tasks.append( + ( + task_id, + current_time + - task_state.execution_start_time, + ) + ) + + # Force cleanup of stuck tasks + for task_id, execution_duration in stuck_tasks: + logger.warning( + f"{queue_name}: Detected stuck task {task_id} (execution time: {execution_duration:.1f}s), forcing cleanup" + ) + async with task_states_lock: + if task_id in task_states: + task_state = task_states[task_id] + if not task_state.future.done(): + task_state.future.set_exception( + HealthCheckTimeoutError( + max_task_duration, execution_duration + ) + ) + task_states.pop(task_id, None) + + # Worker recovery logic + current_tasks = set(tasks) + done_tasks = {t for t in current_tasks if t.done()} + tasks.difference_update(done_tasks) + + active_tasks_count = len(tasks) + workers_needed = max_size - active_tasks_count + + if workers_needed > 0: + logger.info( + f"{queue_name}: Creating {workers_needed} new workers" + ) + new_tasks = set() + for _ in range(workers_needed): + task = asyncio.create_task(worker()) + new_tasks.add(task) + task.add_done_callback(tasks.discard) + tasks.update(new_tasks) + + except Exception as e: + logger.error(f"{queue_name}: Error in enhanced health check: {str(e)}") + finally: + logger.debug(f"{queue_name}: Enhanced health check task exiting") + initialized = False + + async def ensure_workers(): + """Ensure worker system is initialized with enhanced error handling""" + nonlocal initialized, worker_health_check_task, tasks, reinit_count + + if initialized: + return + + async with initialization_lock: + if initialized: + return + + if reinit_count > 0: + reinit_count += 1 + logger.warning( + f"{queue_name}: Reinitializing system (count: {reinit_count})" + ) + else: + reinit_count = 1 + + # Clean up completed tasks + current_tasks = set(tasks) + done_tasks = {t for t in current_tasks if t.done()} + tasks.difference_update(done_tasks) + + active_tasks_count = len(tasks) + if active_tasks_count > 0 and reinit_count > 1: + logger.warning( + f"{queue_name}: {active_tasks_count} tasks still running during reinitialization" + ) + + # Create worker tasks + workers_needed = max_size - active_tasks_count + for _ in range(workers_needed): + task = asyncio.create_task(worker()) + tasks.add(task) + task.add_done_callback(tasks.discard) + + # Start enhanced health check + worker_health_check_task = asyncio.create_task(enhanced_health_check()) + + initialized = True + # Log dynamic timeout configuration + timeout_info = [] + if llm_timeout is not None: + timeout_info.append(f"Func: {llm_timeout}s") + if max_execution_timeout is not None: + timeout_info.append(f"Worker: {max_execution_timeout}s") + if max_task_duration is not None: + timeout_info.append(f"Health Check: {max_task_duration}s") + + timeout_str = ( + f"(Timeouts: {', '.join(timeout_info)})" if timeout_info else "" + ) + logger.info( + f"{queue_name}: {workers_needed} new workers initialized {timeout_str}" + ) + + async def shutdown(): + """Gracefully shut down all workers and cleanup resources""" + logger.info(f"{queue_name}: Shutting down priority queue workers") + + shutdown_event.set() + + # Cancel all active futures + for future in list(active_futures): + if not future.done(): + future.cancel() + + # Cancel all pending tasks + async with task_states_lock: + for task_id, task_state in list(task_states.items()): + if not task_state.future.done(): + task_state.future.cancel() + task_states.clear() + + # Wait for queue to empty with timeout + try: + await asyncio.wait_for(queue.join(), timeout=5.0) + except asyncio.TimeoutError: + logger.warning( + f"{queue_name}: Timeout waiting for queue to empty during shutdown" + ) + + # Cancel worker tasks + for task in list(tasks): + if not task.done(): + task.cancel() + + # Wait for all tasks to complete + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + # Cancel health check task + if worker_health_check_task and not worker_health_check_task.done(): + worker_health_check_task.cancel() + try: + await worker_health_check_task + except asyncio.CancelledError: + pass + + logger.info(f"{queue_name}: Priority queue workers shutdown complete") + + @wraps(func) + async def wait_func( + *args, _priority=10, _timeout=None, _queue_timeout=None, **kwargs + ): + """ + Execute function with enhanced priority-based concurrency control and timeout handling + + Args: + *args: Positional arguments passed to the function + _priority: Call priority (lower values have higher priority) + _timeout: Maximum time to wait for completion (in seconds, none means determinded by max_execution_timeout of the queue) + _queue_timeout: Maximum time to wait for entering the queue (in seconds) + **kwargs: Keyword arguments passed to the function + + Returns: + The result of the function call + + Raises: + TimeoutError: If the function call times out at any level + QueueFullError: If the queue is full and waiting times out + Any exception raised by the decorated function + """ + await ensure_workers() + + # Generate unique task ID + task_id = f"{id(asyncio.current_task())}_{asyncio.get_event_loop().time()}" + future = asyncio.Future() + + # Create task state + task_state = TaskState( + future=future, start_time=asyncio.get_event_loop().time() + ) + + try: + # Register task state + async with task_states_lock: + task_states[task_id] = task_state + + active_futures.add(future) + + # Get counter for FIFO ordering + nonlocal counter + async with initialization_lock: + current_count = counter + counter += 1 + + # Queue the task with timeout handling + try: + if _queue_timeout is not None: + await asyncio.wait_for( + queue.put( + (_priority, current_count, task_id, args, kwargs) + ), + timeout=_queue_timeout, + ) + else: + await queue.put( + (_priority, current_count, task_id, args, kwargs) + ) + except asyncio.TimeoutError: + raise QueueFullError( + f"{queue_name}: Queue full, timeout after {_queue_timeout} seconds" + ) + except Exception as e: + # Clean up on queue error + if not future.done(): + future.set_exception(e) + raise + + # Wait for result with timeout handling + try: + if _timeout is not None: + return await asyncio.wait_for(future, _timeout) + else: + return await future + except asyncio.TimeoutError: + # This is user-level timeout (asyncio.wait_for caused) + # Mark cancellation request + async with task_states_lock: + if task_id in task_states: + task_states[task_id].cancellation_requested = True + + # Cancel future + if not future.done(): + future.cancel() + + # Wait for worker cleanup with timeout + cleanup_start = asyncio.get_event_loop().time() + while ( + task_id in task_states + and asyncio.get_event_loop().time() - cleanup_start + < cleanup_timeout + ): + await asyncio.sleep(0.1) + + raise TimeoutError( + f"{queue_name}: User timeout after {_timeout} seconds" + ) + except WorkerTimeoutError as e: + # This is Worker-level timeout, directly propagate exception information + raise TimeoutError(f"{queue_name}: {str(e)}") + except HealthCheckTimeoutError as e: + # This is Health Check-level timeout, directly propagate exception information + raise TimeoutError(f"{queue_name}: {str(e)}") + + finally: + # Ensure cleanup + active_futures.discard(future) + async with task_states_lock: + task_states.pop(task_id, None) + + # Add shutdown method to decorated function + wait_func.shutdown = shutdown + + return wait_func + + return final_decro + + +def wrap_embedding_func_with_attrs(**kwargs): + """Decorator to add embedding dimension and token limit attributes to embedding functions. + + This decorator wraps an async embedding function and returns an EmbeddingFunc instance + that automatically handles dimension parameter injection and attribute management. + + WARNING: DO NOT apply this decorator to wrapper functions that call other + decorated embedding functions. This will cause double decoration and parameter + injection conflicts. + + Correct usage patterns: + + 1. Direct decoration: + ```python + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model") + async def my_embed(texts, embedding_dim=None): + # Direct implementation + return embeddings + ``` + 2. Double decoration: + ```python + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192, model_name="my_embedding_model") + @retry(...) + async def my_embed(texts, ...): + # Base implementation + pass + + @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=4096, model_name="another_embedding_model") + # Note: No @retry here! + async def my_new_embed(texts, ...): + # CRITICAL: Call .func to access unwrapped function + return await my_embed.func(texts, ...) # ✅ Correct + # return await my_embed(texts, ...) # ❌ Wrong - double decoration! + ``` + + The decorated function becomes an EmbeddingFunc instance with: + - embedding_dim: The embedding dimension + - max_token_size: Maximum token limit (optional) + - model_name: Model name (optional) + - func: The original unwrapped function (access via .func) + - __call__: Wrapper that injects embedding_dim parameter + + Args: + embedding_dim: The dimension of embedding vectors + max_token_size: Maximum number of tokens (optional) + send_dimensions: Whether to pass embedding_dim as a keyword argument (for models with configurable embedding dimensions). + + Returns: + A decorator that wraps the function as an EmbeddingFunc instance + """ + + def final_decro(func) -> EmbeddingFunc: + new_func = EmbeddingFunc(**kwargs, func=func) + return new_func + + return final_decro + + +def load_json(file_name): + if not os.path.exists(file_name): + return None + with open(file_name, encoding="utf-8-sig") as f: + return json.load(f) + + +def _sanitize_string_for_json(text: str) -> str: + """Remove characters that cannot be encoded in UTF-8 for JSON serialization. + + Uses regex for optimal performance with zero-copy optimization for clean strings. + Fast detection path for clean strings (99% of cases) with efficient removal for dirty strings. + + Args: + text: String to sanitize + + Returns: + Original string if clean (zero-copy), sanitized string if dirty + """ + if not text: + return text + + # Fast path: Check if sanitization is needed using C-level regex search + if not _SURROGATE_PATTERN.search(text): + return text # Zero-copy for clean strings - most common case + + # Slow path: Remove problematic characters using C-level regex substitution + return _SURROGATE_PATTERN.sub("", text) + + +class SanitizingJSONEncoder(json.JSONEncoder): + """ + Custom JSON encoder that sanitizes data during serialization. + + This encoder cleans strings during the encoding process without creating + a full copy of the data structure, making it memory-efficient for large datasets. + """ + + def encode(self, o): + """Override encode method to handle simple string cases""" + if isinstance(o, str): + return json.encoder.encode_basestring(_sanitize_string_for_json(o)) + return super().encode(o) + + def iterencode(self, o, _one_shot=False): + """ + Override iterencode to sanitize strings during serialization. + This is the core method that handles complex nested structures. + """ + # Preprocess: sanitize all strings in the object + sanitized = self._sanitize_for_encoding(o) + + # Call parent's iterencode with sanitized data + for chunk in super().iterencode(sanitized, _one_shot): + yield chunk + + def _sanitize_for_encoding(self, obj): + """ + Recursively sanitize strings in an object. + Creates new objects only when necessary to avoid deep copies. + + Args: + obj: Object to sanitize + + Returns: + Sanitized object with cleaned strings + """ + if isinstance(obj, str): + return _sanitize_string_for_json(obj) + + elif isinstance(obj, dict): + # Create new dict with sanitized keys and values + new_dict = {} + for k, v in obj.items(): + clean_k = _sanitize_string_for_json(k) if isinstance(k, str) else k + clean_v = self._sanitize_for_encoding(v) + new_dict[clean_k] = clean_v + return new_dict + + elif isinstance(obj, (list, tuple)): + # Sanitize list/tuple elements + cleaned = [self._sanitize_for_encoding(item) for item in obj] + return type(obj)(cleaned) if isinstance(obj, tuple) else cleaned + + else: + # Numbers, booleans, None, etc. remain unchanged + return obj + + +def write_json(json_obj, file_name): + """ + Write JSON data to file with optimized sanitization strategy. + + This function uses a two-stage approach: + 1. Fast path: Try direct serialization (works for clean data ~99% of time) + 2. Slow path: Use custom encoder that sanitizes during serialization + + The custom encoder approach avoids creating a deep copy of the data, + making it memory-efficient. When sanitization occurs, the caller should + reload the cleaned data from the file to update shared memory. + + Args: + json_obj: Object to serialize (may be a shallow copy from shared memory) + file_name: Output file path + + Returns: + bool: True if sanitization was applied (caller should reload data), + False if direct write succeeded (no reload needed) + """ + try: + # Strategy 1: Fast path - try direct serialization + with open(file_name, "w", encoding="utf-8") as f: + json.dump(json_obj, f, indent=2, ensure_ascii=False) + return False # No sanitization needed, no reload required + + except (UnicodeEncodeError, UnicodeDecodeError) as e: + logger.debug(f"Direct JSON write failed, using sanitizing encoder: {e}") + + # Strategy 2: Use custom encoder (sanitizes during serialization, zero memory copy) + with open(file_name, "w", encoding="utf-8") as f: + json.dump(json_obj, f, indent=2, ensure_ascii=False, cls=SanitizingJSONEncoder) + + logger.info(f"JSON sanitization applied during write: {file_name}") + return True # Sanitization applied, reload recommended + + +class TokenizerInterface(Protocol): + """ + Defines the interface for a tokenizer, requiring encode and decode methods. + """ + + def encode(self, content: str) -> List[int]: + """Encodes a string into a list of tokens.""" + ... + + def decode(self, tokens: List[int]) -> str: + """Decodes a list of tokens into a string.""" + ... + + +class Tokenizer: + """ + A wrapper around a tokenizer to provide a consistent interface for encoding and decoding. + """ + + def __init__(self, model_name: str, tokenizer: TokenizerInterface): + """ + Initializes the Tokenizer with a tokenizer model name and a tokenizer instance. + + Args: + model_name: The associated model name for the tokenizer. + tokenizer: An instance of a class implementing the TokenizerInterface. + """ + self.model_name: str = model_name + self.tokenizer: TokenizerInterface = tokenizer + + def encode(self, content: str) -> List[int]: + """ + Encodes a string into a list of tokens using the underlying tokenizer. + + Args: + content: The string to encode. + + Returns: + A list of integer tokens. + """ + return self.tokenizer.encode(content) + + def decode(self, tokens: List[int]) -> str: + """ + Decodes a list of tokens into a string using the underlying tokenizer. + + Args: + tokens: A list of integer tokens to decode. + + Returns: + The decoded string. + """ + return self.tokenizer.decode(tokens) + + +class TiktokenTokenizer(Tokenizer): + """ + A Tokenizer implementation using the tiktoken library. + """ + + def __init__(self, model_name: str = "gpt-4o-mini"): + """ + Initializes the TiktokenTokenizer with a specified model name. + + Args: + model_name: The model name for the tiktoken tokenizer to use. Defaults to "gpt-4o-mini". + + Raises: + ImportError: If tiktoken is not installed. + ValueError: If the model_name is invalid. + """ + try: + import tiktoken + except ImportError: + raise ImportError( + "tiktoken is not installed. Please install it with `pip install tiktoken` or define custom `tokenizer_func`." + ) + + try: + tokenizer = tiktoken.encoding_for_model(model_name) + super().__init__(model_name=model_name, tokenizer=tokenizer) + except KeyError: + raise ValueError(f"Invalid model_name: {model_name}.") + + +def pack_user_ass_to_openai_messages(*args: str): + roles = ["user", "assistant"] + return [ + {"role": roles[i % 2], "content": content} for i, content in enumerate(args) + ] + + +def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]: + """Split a string by multiple markers""" + if not markers: + return [content] + content = content if content is not None else "" + results = re.split("|".join(re.escape(marker) for marker in markers), content) + return [r.strip() for r in results if r.strip()] + + +def is_float_regex(value: str) -> bool: + return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)) + + +def truncate_list_by_token_size( + list_data: list[Any], + key: Callable[[Any], str], + max_token_size: int, + tokenizer: Tokenizer, +) -> list[int]: + """Truncate a list of data by token size""" + if max_token_size <= 0: + return [] + tokens = 0 + for i, data in enumerate(list_data): + tokens += len(tokenizer.encode(key(data))) + if tokens > max_token_size: + return list_data[:i] + return list_data + + +def cosine_similarity(v1, v2): + """Calculate cosine similarity between two vectors""" + dot_product = np.dot(v1, v2) + norm1 = np.linalg.norm(v1) + norm2 = np.linalg.norm(v2) + return dot_product / (norm1 * norm2) + + +async def handle_cache( + hashing_kv, + args_hash, + prompt, + mode="default", + cache_type="unknown", +) -> tuple[str, int] | None: + """Generic cache handling function with flattened cache keys + + Returns: + tuple[str, int] | None: (content, create_time) if cache hit, None if cache miss + """ + if hashing_kv is None: + return None + + if mode != "default": # handle cache for all type of query + if not hashing_kv.global_config.get("enable_llm_cache"): + return None + else: # handle cache for entity extraction + if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"): + return None + + # Use flattened cache key format: {mode}:{cache_type}:{hash} + flattened_key = generate_cache_key(mode, cache_type, args_hash) + cache_entry = await hashing_kv.get_by_id(flattened_key) + if cache_entry: + logger.debug(f"Flattened cache hit(key:{flattened_key})") + content = cache_entry["return"] + timestamp = cache_entry.get("create_time", 0) + return content, timestamp + + logger.debug(f"Cache missed(mode:{mode} type:{cache_type})") + return None + + +@dataclass +class CacheData: + args_hash: str + content: str + prompt: str + mode: str = "default" + cache_type: str = "query" + chunk_id: str | None = None + queryparam: dict | None = None + + +async def save_to_cache(hashing_kv, cache_data: CacheData): + """Save data to cache using flattened key structure. + + Args: + hashing_kv: The key-value storage for caching + cache_data: The cache data to save + """ + # Skip if storage is None or content is a streaming response + if hashing_kv is None or not cache_data.content: + return + + # If content is a streaming response, don't cache it + if hasattr(cache_data.content, "__aiter__"): + logger.debug("Streaming response detected, skipping cache") + return + + # Use flattened cache key format: {mode}:{cache_type}:{hash} + flattened_key = generate_cache_key( + cache_data.mode, cache_data.cache_type, cache_data.args_hash + ) + + # Check if we already have identical content cached + existing_cache = await hashing_kv.get_by_id(flattened_key) + if existing_cache: + existing_content = existing_cache.get("return") + if existing_content == cache_data.content: + logger.warning( + f"Cache duplication detected for {flattened_key}, skipping update" + ) + return + + # Create cache entry with flattened structure + cache_entry = { + "return": cache_data.content, + "cache_type": cache_data.cache_type, + "chunk_id": cache_data.chunk_id if cache_data.chunk_id is not None else None, + "original_prompt": cache_data.prompt, + "queryparam": cache_data.queryparam + if cache_data.queryparam is not None + else None, + } + + logger.info(f" == LLM cache == saving: {flattened_key}") + + # Save using flattened key + await hashing_kv.upsert({flattened_key: cache_entry}) + + +def safe_unicode_decode(content): + # Regular expression to find all Unicode escape sequences of the form \uXXXX + unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})") + + # Function to replace the Unicode escape with the actual character + def replace_unicode_escape(match): + # Convert the matched hexadecimal value into the actual Unicode character + return chr(int(match.group(1), 16)) + + # Perform the substitution + decoded_content = unicode_escape_pattern.sub( + replace_unicode_escape, content.decode("utf-8") + ) + + return decoded_content + + +def exists_func(obj, func_name: str) -> bool: + """Check if a function exists in an object or not. + :param obj: + :param func_name: + :return: True / False + """ + if callable(getattr(obj, func_name, None)): + return True + else: + return False + + +def always_get_an_event_loop() -> asyncio.AbstractEventLoop: + """ + Ensure that there is always an event loop available. + + This function tries to get the current event loop. If the current event loop is closed or does not exist, + it creates a new event loop and sets it as the current event loop. + + Returns: + asyncio.AbstractEventLoop: The current or newly created event loop. + """ + try: + # Try to get the current event loop + current_loop = asyncio.get_event_loop() + if current_loop.is_closed(): + raise RuntimeError("Event loop is closed.") + return current_loop + + except RuntimeError: + # If no event loop exists or it is closed, create a new one + logger.info("Creating a new event loop in main thread.") + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + return new_loop + + +async def aexport_data( + chunk_entity_relation_graph, + entities_vdb, + relationships_vdb, + output_path: str, + file_format: str = "csv", + include_vector_data: bool = False, +) -> None: + """ + Asynchronously exports all entities, relations, and relationships to various formats. + + Args: + chunk_entity_relation_graph: Graph storage instance for entities and relations + entities_vdb: Vector database storage for entities + relationships_vdb: Vector database storage for relationships + output_path: The path to the output file (including extension). + file_format: Output format - "csv", "excel", "md", "txt". + - csv: Comma-separated values file + - excel: Microsoft Excel file with multiple sheets + - md: Markdown tables + - txt: Plain text formatted output + include_vector_data: Whether to include data from the vector database. + """ + # Collect data + entities_data = [] + relations_data = [] + relationships_data = [] + + # --- Entities --- + all_entities = await chunk_entity_relation_graph.get_all_labels() + for entity_name in all_entities: + # Get entity information from graph + node_data = await chunk_entity_relation_graph.get_node(entity_name) + source_id = node_data.get("source_id") if node_data else None + + entity_info = { + "graph_data": node_data, + "source_id": source_id, + } + + # Optional: Get vector database information + if include_vector_data: + entity_id = compute_mdhash_id(entity_name, prefix="ent-") + vector_data = await entities_vdb.get_by_id(entity_id) + entity_info["vector_data"] = vector_data + + entity_row = { + "entity_name": entity_name, + "source_id": source_id, + "graph_data": str( + entity_info["graph_data"] + ), # Convert to string to ensure compatibility + } + if include_vector_data and "vector_data" in entity_info: + entity_row["vector_data"] = str(entity_info["vector_data"]) + entities_data.append(entity_row) + + # --- Relations --- + for src_entity in all_entities: + for tgt_entity in all_entities: + if src_entity == tgt_entity: + continue + + edge_exists = await chunk_entity_relation_graph.has_edge( + src_entity, tgt_entity + ) + if edge_exists: + # Get edge information from graph + edge_data = await chunk_entity_relation_graph.get_edge( + src_entity, tgt_entity + ) + source_id = edge_data.get("source_id") if edge_data else None + + relation_info = { + "graph_data": edge_data, + "source_id": source_id, + } + + # Optional: Get vector database information + if include_vector_data: + rel_id = compute_mdhash_id(src_entity + tgt_entity, prefix="rel-") + vector_data = await relationships_vdb.get_by_id(rel_id) + relation_info["vector_data"] = vector_data + + relation_row = { + "src_entity": src_entity, + "tgt_entity": tgt_entity, + "source_id": relation_info["source_id"], + "graph_data": str(relation_info["graph_data"]), # Convert to string + } + if include_vector_data and "vector_data" in relation_info: + relation_row["vector_data"] = str(relation_info["vector_data"]) + relations_data.append(relation_row) + + # --- Relationships (from VectorDB) --- + all_relationships = await relationships_vdb.client_storage + for rel in all_relationships["data"]: + relationships_data.append( + { + "relationship_id": rel["__id__"], + "data": str(rel), # Convert to string for compatibility + } + ) + + # Export based on format + if file_format == "csv": + # CSV export + with open(output_path, "w", newline="", encoding="utf-8") as csvfile: + # Entities + if entities_data: + csvfile.write("# ENTITIES\n") + writer = csv.DictWriter(csvfile, fieldnames=entities_data[0].keys()) + writer.writeheader() + writer.writerows(entities_data) + csvfile.write("\n\n") + + # Relations + if relations_data: + csvfile.write("# RELATIONS\n") + writer = csv.DictWriter(csvfile, fieldnames=relations_data[0].keys()) + writer.writeheader() + writer.writerows(relations_data) + csvfile.write("\n\n") + + # Relationships + if relationships_data: + csvfile.write("# RELATIONSHIPS\n") + writer = csv.DictWriter( + csvfile, fieldnames=relationships_data[0].keys() + ) + writer.writeheader() + writer.writerows(relationships_data) + + elif file_format == "excel": + # Excel export + import pandas as pd + + entities_df = pd.DataFrame(entities_data) if entities_data else pd.DataFrame() + relations_df = ( + pd.DataFrame(relations_data) if relations_data else pd.DataFrame() + ) + relationships_df = ( + pd.DataFrame(relationships_data) if relationships_data else pd.DataFrame() + ) + + with pd.ExcelWriter(output_path, engine="xlsxwriter") as writer: + if not entities_df.empty: + entities_df.to_excel(writer, sheet_name="Entities", index=False) + if not relations_df.empty: + relations_df.to_excel(writer, sheet_name="Relations", index=False) + if not relationships_df.empty: + relationships_df.to_excel( + writer, sheet_name="Relationships", index=False + ) + + elif file_format == "md": + # Markdown export + with open(output_path, "w", encoding="utf-8") as mdfile: + mdfile.write("# LightRAG Data Export\n\n") + + # Entities + mdfile.write("## Entities\n\n") + if entities_data: + # Write header + mdfile.write("| " + " | ".join(entities_data[0].keys()) + " |\n") + mdfile.write( + "| " + " | ".join(["---"] * len(entities_data[0].keys())) + " |\n" + ) + + # Write rows + for entity in entities_data: + mdfile.write( + "| " + " | ".join(str(v) for v in entity.values()) + " |\n" + ) + mdfile.write("\n\n") + else: + mdfile.write("*No entity data available*\n\n") + + # Relations + mdfile.write("## Relations\n\n") + if relations_data: + # Write header + mdfile.write("| " + " | ".join(relations_data[0].keys()) + " |\n") + mdfile.write( + "| " + " | ".join(["---"] * len(relations_data[0].keys())) + " |\n" + ) + + # Write rows + for relation in relations_data: + mdfile.write( + "| " + " | ".join(str(v) for v in relation.values()) + " |\n" + ) + mdfile.write("\n\n") + else: + mdfile.write("*No relation data available*\n\n") + + # Relationships + mdfile.write("## Relationships\n\n") + if relationships_data: + # Write header + mdfile.write("| " + " | ".join(relationships_data[0].keys()) + " |\n") + mdfile.write( + "| " + + " | ".join(["---"] * len(relationships_data[0].keys())) + + " |\n" + ) + + # Write rows + for relationship in relationships_data: + mdfile.write( + "| " + + " | ".join(str(v) for v in relationship.values()) + + " |\n" + ) + else: + mdfile.write("*No relationship data available*\n\n") + + elif file_format == "txt": + # Plain text export + with open(output_path, "w", encoding="utf-8") as txtfile: + txtfile.write("LIGHTRAG DATA EXPORT\n") + txtfile.write("=" * 80 + "\n\n") + + # Entities + txtfile.write("ENTITIES\n") + txtfile.write("-" * 80 + "\n") + if entities_data: + # Create fixed width columns + col_widths = { + k: max(len(k), max(len(str(e[k])) for e in entities_data)) + for k in entities_data[0] + } + header = " ".join(k.ljust(col_widths[k]) for k in entities_data[0]) + txtfile.write(header + "\n") + txtfile.write("-" * len(header) + "\n") + + # Write rows + for entity in entities_data: + row = " ".join( + str(v).ljust(col_widths[k]) for k, v in entity.items() + ) + txtfile.write(row + "\n") + txtfile.write("\n\n") + else: + txtfile.write("No entity data available\n\n") + + # Relations + txtfile.write("RELATIONS\n") + txtfile.write("-" * 80 + "\n") + if relations_data: + # Create fixed width columns + col_widths = { + k: max(len(k), max(len(str(r[k])) for r in relations_data)) + for k in relations_data[0] + } + header = " ".join(k.ljust(col_widths[k]) for k in relations_data[0]) + txtfile.write(header + "\n") + txtfile.write("-" * len(header) + "\n") + + # Write rows + for relation in relations_data: + row = " ".join( + str(v).ljust(col_widths[k]) for k, v in relation.items() + ) + txtfile.write(row + "\n") + txtfile.write("\n\n") + else: + txtfile.write("No relation data available\n\n") + + # Relationships + txtfile.write("RELATIONSHIPS\n") + txtfile.write("-" * 80 + "\n") + if relationships_data: + # Create fixed width columns + col_widths = { + k: max(len(k), max(len(str(r[k])) for r in relationships_data)) + for k in relationships_data[0] + } + header = " ".join( + k.ljust(col_widths[k]) for k in relationships_data[0] + ) + txtfile.write(header + "\n") + txtfile.write("-" * len(header) + "\n") + + # Write rows + for relationship in relationships_data: + row = " ".join( + str(v).ljust(col_widths[k]) for k, v in relationship.items() + ) + txtfile.write(row + "\n") + else: + txtfile.write("No relationship data available\n\n") + + else: + raise ValueError( + f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt" + ) + if file_format is not None: + print(f"Data exported to: {output_path} with format: {file_format}") + else: + print("Data displayed as table format") + + +def export_data( + chunk_entity_relation_graph, + entities_vdb, + relationships_vdb, + output_path: str, + file_format: str = "csv", + include_vector_data: bool = False, +) -> None: + """ + Synchronously exports all entities, relations, and relationships to various formats. + + Args: + chunk_entity_relation_graph: Graph storage instance for entities and relations + entities_vdb: Vector database storage for entities + relationships_vdb: Vector database storage for relationships + output_path: The path to the output file (including extension). + file_format: Output format - "csv", "excel", "md", "txt". + - csv: Comma-separated values file + - excel: Microsoft Excel file with multiple sheets + - md: Markdown tables + - txt: Plain text formatted output + include_vector_data: Whether to include data from the vector database. + """ + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + loop.run_until_complete( + aexport_data( + chunk_entity_relation_graph, + entities_vdb, + relationships_vdb, + output_path, + file_format, + include_vector_data, + ) + ) + + +def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]: + """Lazily import a class from an external module based on the package of the caller.""" + # Get the caller's module and package + import inspect + + caller_frame = inspect.currentframe().f_back + module = inspect.getmodule(caller_frame) + package = module.__package__ if module else None + + def import_class(*args: Any, **kwargs: Any): + import importlib + + module = importlib.import_module(module_name, package=package) + cls = getattr(module, class_name) + return cls(*args, **kwargs) + + return import_class + + +async def update_chunk_cache_list( + chunk_id: str, + text_chunks_storage: "BaseKVStorage", + cache_keys: list[str], + cache_scenario: str = "batch_update", +) -> None: + """Update chunk's llm_cache_list with the given cache keys + + Args: + chunk_id: Chunk identifier + text_chunks_storage: Text chunks storage instance + cache_keys: List of cache keys to add to the list + cache_scenario: Description of the cache scenario for logging + """ + if not cache_keys: + return + + try: + chunk_data = await text_chunks_storage.get_by_id(chunk_id) + if chunk_data: + # Ensure llm_cache_list exists + if "llm_cache_list" not in chunk_data: + chunk_data["llm_cache_list"] = [] + + # Add cache keys to the list if not already present + existing_keys = set(chunk_data["llm_cache_list"]) + new_keys = [key for key in cache_keys if key not in existing_keys] + + if new_keys: + chunk_data["llm_cache_list"].extend(new_keys) + + # Update the chunk in storage + await text_chunks_storage.upsert({chunk_id: chunk_data}) + logger.debug( + f"Updated chunk {chunk_id} with {len(new_keys)} cache keys ({cache_scenario})" + ) + except Exception as e: + logger.warning( + f"Failed to update chunk {chunk_id} with cache references on {cache_scenario}: {e}" + ) + + +def remove_think_tags(text: str) -> str: + """Remove ... tags from the text + Remove orphon ... tags from the text also""" + return re.sub( + r"^(.*?|.*)", "", text, flags=re.DOTALL + ).strip() + + +async def use_llm_func_with_cache( + user_prompt: str, + use_llm_func: callable, + llm_response_cache: "BaseKVStorage | None" = None, + system_prompt: str | None = None, + max_tokens: int = None, + history_messages: list[dict[str, str]] = None, + cache_type: str = "extract", + chunk_id: str | None = None, + cache_keys_collector: list = None, +) -> tuple[str, int]: + """Call LLM function with cache support and text sanitization + + If cache is available and enabled (determined by handle_cache based on mode), + retrieve result from cache; otherwise call LLM function and save result to cache. + + This function applies text sanitization to prevent UTF-8 encoding errors for all LLM providers. + + Args: + input_text: Input text to send to LLM + use_llm_func: LLM function with higher priority + llm_response_cache: Cache storage instance + max_tokens: Maximum tokens for generation + history_messages: History messages list + cache_type: Type of cache + chunk_id: Chunk identifier to store in cache + text_chunks_storage: Text chunks storage to update llm_cache_list + cache_keys_collector: Optional list to collect cache keys for batch processing + + Returns: + tuple[str, int]: (LLM response text, timestamp) + - For cache hits: (content, cache_create_time) + - For cache misses: (content, current_timestamp) + """ + # Sanitize input text to prevent UTF-8 encoding errors for all LLM providers + safe_user_prompt = sanitize_text_for_encoding(user_prompt) + safe_system_prompt = ( + sanitize_text_for_encoding(system_prompt) if system_prompt else None + ) + + # Sanitize history messages if provided + safe_history_messages = None + if history_messages: + safe_history_messages = [] + for i, msg in enumerate(history_messages): + safe_msg = msg.copy() + if "content" in safe_msg: + safe_msg["content"] = sanitize_text_for_encoding(safe_msg["content"]) + safe_history_messages.append(safe_msg) + history = json.dumps(safe_history_messages, ensure_ascii=False) + else: + history = None + + if llm_response_cache: + prompt_parts = [] + if safe_user_prompt: + prompt_parts.append(safe_user_prompt) + if safe_system_prompt: + prompt_parts.append(safe_system_prompt) + if history: + prompt_parts.append(history) + _prompt = "\n".join(prompt_parts) + + arg_hash = compute_args_hash(_prompt) + # Generate cache key for this LLM call + cache_key = generate_cache_key("default", cache_type, arg_hash) + + cached_result = await handle_cache( + llm_response_cache, + arg_hash, + _prompt, + "default", + cache_type=cache_type, + ) + if cached_result: + content, timestamp = cached_result + logger.debug(f"Found cache for {arg_hash}") + statistic_data["llm_cache"] += 1 + + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + + return content, timestamp + statistic_data["llm_call"] += 1 + + # Call LLM with sanitized input + kwargs = {} + if safe_history_messages: + kwargs["history_messages"] = safe_history_messages + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + + res: str = await use_llm_func( + safe_user_prompt, system_prompt=safe_system_prompt, **kwargs + ) + + res = remove_think_tags(res) + + # Generate timestamp for cache miss (LLM call completion time) + current_timestamp = int(time.time()) + + if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"): + await save_to_cache( + llm_response_cache, + CacheData( + args_hash=arg_hash, + content=res, + prompt=_prompt, + cache_type=cache_type, + chunk_id=chunk_id, + ), + ) + + # Add cache key to collector if provided + if cache_keys_collector is not None: + cache_keys_collector.append(cache_key) + + return res, current_timestamp + + # When cache is disabled, directly call LLM with sanitized input + kwargs = {} + if safe_history_messages: + kwargs["history_messages"] = safe_history_messages + if max_tokens is not None: + kwargs["max_tokens"] = max_tokens + + try: + res = await use_llm_func( + safe_user_prompt, system_prompt=safe_system_prompt, **kwargs + ) + except Exception as e: + # Add [LLM func] prefix to error message + error_msg = f"[LLM func] {str(e)}" + # Re-raise with the same exception type but modified message + raise type(e)(error_msg) from e + + # Generate timestamp for non-cached LLM call + current_timestamp = int(time.time()) + return remove_think_tags(res), current_timestamp + + +def get_content_summary(content: str, max_length: int = 250) -> str: + """Get summary of document content + + Args: + content: Original document content + max_length: Maximum length of summary + + Returns: + Truncated content with ellipsis if needed + """ + content = content.strip() + if len(content) <= max_length: + return content + return content[:max_length] + "..." + + +def sanitize_and_normalize_extracted_text( + input_text: str, remove_inner_quotes=False +) -> str: + """Santitize and normalize extracted text + Args: + input_text: text string to be processed + is_name: whether the input text is a entity or relation name + + Returns: + Santitized and normalized text string + """ + safe_input_text = sanitize_text_for_encoding(input_text) + if safe_input_text: + normalized_text = normalize_extracted_info( + safe_input_text, remove_inner_quotes=remove_inner_quotes + ) + return normalized_text + return "" + + +def normalize_extracted_info(name: str, remove_inner_quotes=False) -> str: + """Normalize entity/relation names and description with the following rules: + - Clean HTML tags (paragraph and line break tags) + - Convert Chinese symbols to English symbols + - Remove spaces between Chinese characters + - Remove spaces between Chinese characters and English letters/numbers + - Preserve spaces within English text and numbers + - Replace Chinese parentheses with English parentheses + - Replace Chinese dash with English dash + - Remove English quotation marks from the beginning and end of the text + - Remove English quotation marks in and around chinese + - Remove Chinese quotation marks + - Filter out short numeric-only text (length < 3 and only digits/dots) + - remove_inner_quotes = True + remove Chinese quotes + remove English quotes in and around chinese + Convert non-breaking spaces to regular spaces + Convert narrow non-breaking spaces after non-digits to regular spaces + + Args: + name: Entity name to normalize + is_entity: Whether this is an entity name (affects quote handling) + + Returns: + Normalized entity name + """ + # Clean HTML tags - remove paragraph and line break tags + name = re.sub(r"||

", "", name, flags=re.IGNORECASE) + name = re.sub(r"||
", "", name, flags=re.IGNORECASE) + + # Chinese full-width letters to half-width (A-Z, a-z) + name = name.translate( + str.maketrans( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", + ) + ) + + # Chinese full-width numbers to half-width + name = name.translate(str.maketrans("0123456789", "0123456789")) + + # Chinese full-width symbols to half-width + name = name.replace("-", "-") # Chinese minus + name = name.replace("+", "+") # Chinese plus + name = name.replace("/", "/") # Chinese slash + name = name.replace("*", "*") # Chinese asterisk + + # Replace Chinese parentheses with English parentheses + name = name.replace("(", "(").replace(")", ")") + + # Replace Chinese dash with English dash (additional patterns) + name = name.replace("—", "-").replace("-", "-") + + # Chinese full-width space to regular space (after other replacements) + name = name.replace(" ", " ") + + # Use regex to remove spaces between Chinese characters + # Regex explanation: + # (?<=[\u4e00-\u9fa5]): Positive lookbehind for Chinese character + # \s+: One or more whitespace characters + # (?=[\u4e00-\u9fa5]): Positive lookahead for Chinese character + name = re.sub(r"(?<=[\u4e00-\u9fa5])\s+(?=[\u4e00-\u9fa5])", "", name) + + # Remove spaces between Chinese and English/numbers/symbols + name = re.sub( + r"(?<=[\u4e00-\u9fa5])\s+(?=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])", "", name + ) + name = re.sub( + r"(?<=[a-zA-Z0-9\(\)\[\]@#$%!&\*\-=+_])\s+(?=[\u4e00-\u9fa5])", "", name + ) + + # Remove outer quotes + if len(name) >= 2: + # Handle double quotes + if name.startswith('"') and name.endswith('"'): + inner_content = name[1:-1] + if '"' not in inner_content: # No double quotes inside + name = inner_content + + # Handle single quotes + if name.startswith("'") and name.endswith("'"): + inner_content = name[1:-1] + if "'" not in inner_content: # No single quotes inside + name = inner_content + + # Handle Chinese-style double quotes + if name.startswith("“") and name.endswith("”"): + inner_content = name[1:-1] + if "“" not in inner_content and "”" not in inner_content: + name = inner_content + if name.startswith("‘") and name.endswith("’"): + inner_content = name[1:-1] + if "‘" not in inner_content and "’" not in inner_content: + name = inner_content + + # Handle Chinese-style book title mark + if name.startswith("《") and name.endswith("》"): + inner_content = name[1:-1] + if "《" not in inner_content and "》" not in inner_content: + name = inner_content + + if remove_inner_quotes: + # Remove Chinese quotes + name = name.replace("“", "").replace("”", "").replace("‘", "").replace("’", "") + # Remove English queotes in and around chinese + name = re.sub(r"['\"]+(?=[\u4e00-\u9fa5])", "", name) + name = re.sub(r"(?<=[\u4e00-\u9fa5])['\"]+", "", name) + # Convert non-breaking space to regular space + name = name.replace("\u00a0", " ") + # Convert narrow non-breaking space to regular space when after non-digits + name = re.sub(r"(?<=[^\d])\u202F", " ", name) + + # Remove spaces from the beginning and end of the text + name = name.strip() + + # Filter out pure numeric content with length < 3 + if len(name) < 3 and re.match(r"^[0-9]+$", name): + return "" + + def should_filter_by_dots(text): + """ + Check if the string consists only of dots and digits, with at least one dot + Filter cases include: 1.2.3, 12.3, .123, 123., 12.3., .1.23 etc. + """ + return all(c.isdigit() or c == "." for c in text) and "." in text + + if len(name) < 6 and should_filter_by_dots(name): + # Filter out mixed numeric and dot content with length < 6 + return "" + # Filter out mixed numeric and dot content with length < 6, requiring at least one dot + return "" + + return name + + +def sanitize_text_for_encoding(text: str, replacement_char: str = "") -> str: + """Sanitize text to ensure safe UTF-8 encoding by removing or replacing problematic characters. + + This function handles: + - Surrogate characters (the main cause of encoding errors) + - Other invalid Unicode sequences + - Control characters that might cause issues + - Unescape HTML escapes + - Remove control characters + - Whitespace trimming + + Args: + text: Input text to sanitize + replacement_char: Character to use for replacing invalid sequences + + Returns: + Sanitized text that can be safely encoded as UTF-8 + + Raises: + ValueError: When text contains uncleanable encoding issues that cannot be safely processed + """ + if not text: + return text + + try: + # First, strip whitespace + text = text.strip() + + # Early return if text is empty after basic cleaning + if not text: + return text + + # Try to encode/decode to catch any encoding issues early + text.encode("utf-8") + + # Remove or replace surrogate characters (U+D800 to U+DFFF) + # These are the main cause of the encoding error + sanitized = "" + for char in text: + code_point = ord(char) + # Check for surrogate characters + if 0xD800 <= code_point <= 0xDFFF: + # Replace surrogate with replacement character + sanitized += replacement_char + continue + # Check for other problematic characters + elif code_point == 0xFFFE or code_point == 0xFFFF: + # These are non-characters in Unicode + sanitized += replacement_char + continue + else: + sanitized += char + + # Additional cleanup: remove null bytes and other control characters that might cause issues + # (but preserve common whitespace like \t, \n, \r) + sanitized = re.sub( + r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]", replacement_char, sanitized + ) + + # Test final encoding to ensure it's safe + sanitized.encode("utf-8") + + # Unescape HTML escapes + sanitized = html.unescape(sanitized) + + # Remove control characters but preserve common whitespace (\t, \n, \r) + sanitized = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", "", sanitized) + + return sanitized.strip() + + except UnicodeEncodeError as e: + # Critical change: Don't return placeholder, raise exception for caller to handle + error_msg = f"Text contains uncleanable UTF-8 encoding issues: {str(e)[:100]}" + logger.error(f"Text sanitization failed: {error_msg}") + raise ValueError(error_msg) from e + + except Exception as e: + logger.error(f"Text sanitization: Unexpected error: {str(e)}") + # For other exceptions, if no encoding issues detected, return original text + try: + text.encode("utf-8") + return text + except UnicodeEncodeError: + raise ValueError( + f"Text sanitization failed with unexpected error: {str(e)}" + ) from e + + +def check_storage_env_vars(storage_name: str) -> None: + """Check if all required environment variables for storage implementation exist + + Args: + storage_name: Storage implementation name + + Raises: + ValueError: If required environment variables are missing + """ + from lightrag.kg import STORAGE_ENV_REQUIREMENTS + + required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, []) + missing_vars = [var for var in required_vars if var not in os.environ] + + if missing_vars: + raise ValueError( + f"Storage implementation '{storage_name}' requires the following " + f"environment variables: {', '.join(missing_vars)}" + ) + + +def pick_by_weighted_polling( + entities_or_relations: list[dict], + max_related_chunks: int, + min_related_chunks: int = 1, +) -> list[str]: + """ + Linear gradient weighted polling algorithm for text chunk selection. + + This algorithm ensures that entities/relations with higher importance get more text chunks, + forming a linear decreasing allocation pattern. + + Args: + entities_or_relations: List of entities or relations sorted by importance (high to low) + max_related_chunks: Expected number of text chunks for the highest importance entity/relation + min_related_chunks: Expected number of text chunks for the lowest importance entity/relation + + Returns: + List of selected text chunk IDs + """ + if not entities_or_relations: + return [] + + n = len(entities_or_relations) + if n == 1: + # Only one entity/relation, return its first max_related_chunks text chunks + entity_chunks = entities_or_relations[0].get("sorted_chunks", []) + return entity_chunks[:max_related_chunks] + + # Calculate expected text chunk count for each position (linear decrease) + expected_counts = [] + for i in range(n): + # Linear interpolation: from max_related_chunks to min_related_chunks + ratio = i / (n - 1) if n > 1 else 0 + expected = max_related_chunks - ratio * ( + max_related_chunks - min_related_chunks + ) + expected_counts.append(int(round(expected))) + + # First round allocation: allocate by expected values + selected_chunks = [] + used_counts = [] # Track number of chunks used by each entity + total_remaining = 0 # Accumulate remaining quotas + + for i, entity_rel in enumerate(entities_or_relations): + entity_chunks = entity_rel.get("sorted_chunks", []) + expected = expected_counts[i] + + # Actual allocatable count + actual = min(expected, len(entity_chunks)) + selected_chunks.extend(entity_chunks[:actual]) + used_counts.append(actual) + + # Accumulate remaining quota + remaining = expected - actual + if remaining > 0: + total_remaining += remaining + + # Second round allocation: multi-round scanning to allocate remaining quotas + for _ in range(total_remaining): + allocated = False + + # Scan entities one by one, allocate one chunk when finding unused chunks + for i, entity_rel in enumerate(entities_or_relations): + entity_chunks = entity_rel.get("sorted_chunks", []) + + # Check if there are still unused chunks + if used_counts[i] < len(entity_chunks): + # Allocate one chunk + selected_chunks.append(entity_chunks[used_counts[i]]) + used_counts[i] += 1 + allocated = True + break + + # If no chunks were allocated in this round, all entities are exhausted + if not allocated: + break + + return selected_chunks + + +async def pick_by_vector_similarity( + query: str, + text_chunks_storage: "BaseKVStorage", + chunks_vdb: "BaseVectorStorage", + num_of_chunks: int, + entity_info: list[dict[str, Any]], + embedding_func: callable, + query_embedding=None, +) -> list[str]: + """ + Vector similarity-based text chunk selection algorithm. + + This algorithm selects text chunks based on cosine similarity between + the query embedding and text chunk embeddings. + + Args: + query: User's original query string + text_chunks_storage: Text chunks storage instance + chunks_vdb: Vector database storage for chunks + num_of_chunks: Number of chunks to select + entity_info: List of entity information containing chunk IDs + embedding_func: Embedding function to compute query embedding + + Returns: + List of selected text chunk IDs sorted by similarity (highest first) + """ + logger.debug( + f"Vector similarity chunk selection: num_of_chunks={num_of_chunks}, entity_info_count={len(entity_info) if entity_info else 0}" + ) + + if not entity_info or num_of_chunks <= 0: + return [] + + # Collect all unique chunk IDs from entity info + all_chunk_ids = set() + for i, entity in enumerate(entity_info): + chunk_ids = entity.get("sorted_chunks", []) + all_chunk_ids.update(chunk_ids) + + if not all_chunk_ids: + logger.warning( + "Vector similarity chunk selection: no chunk IDs found in entity_info" + ) + return [] + + logger.debug( + f"Vector similarity chunk selection: {len(all_chunk_ids)} unique chunk IDs collected" + ) + + all_chunk_ids = list(all_chunk_ids) + + try: + # Use pre-computed query embedding if provided, otherwise compute it + if query_embedding is None: + query_embedding = await embedding_func([query]) + query_embedding = query_embedding[ + 0 + ] # Extract first embedding from batch result + logger.debug( + "Computed query embedding for vector similarity chunk selection" + ) + else: + logger.debug( + "Using pre-computed query embedding for vector similarity chunk selection" + ) + + # Get chunk embeddings from vector database + chunk_vectors = await chunks_vdb.get_vectors_by_ids(all_chunk_ids) + logger.debug( + f"Vector similarity chunk selection: {len(chunk_vectors)} chunk vectors Retrieved" + ) + + if not chunk_vectors or len(chunk_vectors) != len(all_chunk_ids): + if not chunk_vectors: + logger.warning( + "Vector similarity chunk selection: no vectors retrieved from chunks_vdb" + ) + else: + logger.warning( + f"Vector similarity chunk selection: found {len(chunk_vectors)} but expecting {len(all_chunk_ids)}" + ) + return [] + + # Calculate cosine similarities + similarities = [] + valid_vectors = 0 + for chunk_id in all_chunk_ids: + if chunk_id in chunk_vectors: + chunk_embedding = chunk_vectors[chunk_id] + try: + # Calculate cosine similarity + similarity = cosine_similarity(query_embedding, chunk_embedding) + similarities.append((chunk_id, similarity)) + valid_vectors += 1 + except Exception as e: + logger.warning( + f"Vector similarity chunk selection: failed to calculate similarity for chunk {chunk_id}: {e}" + ) + else: + logger.warning( + f"Vector similarity chunk selection: no vector found for chunk {chunk_id}" + ) + + # Sort by similarity (highest first) and select top num_of_chunks + similarities.sort(key=lambda x: x[1], reverse=True) + selected_chunks = [chunk_id for chunk_id, _ in similarities[:num_of_chunks]] + + logger.debug( + f"Vector similarity chunk selection: {len(selected_chunks)} chunks from {len(all_chunk_ids)} candidates" + ) + + return selected_chunks + + except Exception as e: + logger.error(f"[VECTOR_SIMILARITY] Error in vector similarity sorting: {e}") + import traceback + + logger.error(f"[VECTOR_SIMILARITY] Traceback: {traceback.format_exc()}") + # Fallback to simple truncation + logger.debug("[VECTOR_SIMILARITY] Falling back to simple truncation") + return all_chunk_ids[:num_of_chunks] + + +class TokenTracker: + """Track token usage for LLM calls.""" + + def __init__(self): + self.reset() + + def __enter__(self): + self.reset() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + print(self) + + def reset(self): + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_tokens = 0 + self.call_count = 0 + + def add_usage(self, token_counts): + """Add token usage from one LLM call. + + Args: + token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens + """ + self.prompt_tokens += token_counts.get("prompt_tokens", 0) + self.completion_tokens += token_counts.get("completion_tokens", 0) + + # If total_tokens is provided, use it directly; otherwise calculate the sum + if "total_tokens" in token_counts: + self.total_tokens += token_counts["total_tokens"] + else: + self.total_tokens += token_counts.get( + "prompt_tokens", 0 + ) + token_counts.get("completion_tokens", 0) + + self.call_count += 1 + + def get_usage(self): + """Get current usage statistics.""" + return { + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.total_tokens, + "call_count": self.call_count, + } + + def __str__(self): + usage = self.get_usage() + return ( + f"LLM call count: {usage['call_count']}, " + f"Prompt tokens: {usage['prompt_tokens']}, " + f"Completion tokens: {usage['completion_tokens']}, " + f"Total tokens: {usage['total_tokens']}" + ) + + +async def apply_rerank_if_enabled( + query: str, + retrieved_docs: list[dict], + global_config: dict, + enable_rerank: bool = True, + top_n: int = None, +) -> list[dict]: + """ + Apply reranking to retrieved documents if rerank is enabled. + + Args: + query: The search query + retrieved_docs: List of retrieved documents + global_config: Global configuration containing rerank settings + enable_rerank: Whether to enable reranking from query parameter + top_n: Number of top documents to return after reranking + + Returns: + Reranked documents if rerank is enabled, otherwise original documents + """ + if not enable_rerank or not retrieved_docs: + return retrieved_docs + + rerank_func = global_config.get("rerank_model_func") + if not rerank_func: + logger.warning( + "Rerank is enabled but no rerank model is configured. Please set up a rerank model or set enable_rerank=False in query parameters." + ) + return retrieved_docs + + try: + # Extract document content for reranking + document_texts = [] + for doc in retrieved_docs: + # Try multiple possible content fields + content = ( + doc.get("content") + or doc.get("text") + or doc.get("chunk_content") + or doc.get("document") + or str(doc) + ) + document_texts.append(content) + + # Call the new rerank function that returns index-based results + rerank_results = await rerank_func( + query=query, + documents=document_texts, + top_n=top_n, + ) + + # Process rerank results based on return format + if rerank_results and len(rerank_results) > 0: + # Check if results are in the new index-based format + if isinstance(rerank_results[0], dict) and "index" in rerank_results[0]: + # New format: [{"index": 0, "relevance_score": 0.85}, ...] + reranked_docs = [] + for result in rerank_results: + index = result["index"] + relevance_score = result["relevance_score"] + + # Get original document and add rerank score + if 0 <= index < len(retrieved_docs): + doc = retrieved_docs[index].copy() + doc["rerank_score"] = relevance_score + reranked_docs.append(doc) + + logger.info( + f"Successfully reranked: {len(reranked_docs)} chunks from {len(retrieved_docs)} original chunks" + ) + return reranked_docs + else: + # Legacy format: assume it's already reranked documents + logger.info(f"Using legacy rerank format: {len(rerank_results)} chunks") + return rerank_results[:top_n] if top_n else rerank_results + else: + logger.warning("Rerank returned empty results, using original chunks") + return retrieved_docs + + except Exception as e: + logger.error(f"Error during reranking: {e}, using original chunks") + return retrieved_docs + + +async def process_chunks_unified( + query: str, + unique_chunks: list[dict], + query_param: "QueryParam", + global_config: dict, + source_type: str = "mixed", + chunk_token_limit: int = None, # Add parameter for dynamic token limit +) -> list[dict]: + """ + Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation. + + Args: + query: Search query for reranking + chunks: List of text chunks to process + query_param: Query parameters containing configuration + global_config: Global configuration dictionary + source_type: Source type for logging ("vector", "entity", "relationship", "mixed") + chunk_token_limit: Dynamic token limit for chunks (if None, uses default) + + Returns: + Processed and filtered list of text chunks + """ + if not unique_chunks: + return [] + + origin_count = len(unique_chunks) + + # 1. Apply reranking if enabled and query is provided + if query_param.enable_rerank and query and unique_chunks: + rerank_top_k = query_param.chunk_top_k or len(unique_chunks) + unique_chunks = await apply_rerank_if_enabled( + query=query, + retrieved_docs=unique_chunks, + global_config=global_config, + enable_rerank=query_param.enable_rerank, + top_n=rerank_top_k, + ) + + # 2. Filter by minimum rerank score if reranking is enabled + if query_param.enable_rerank and unique_chunks: + min_rerank_score = global_config.get("min_rerank_score", 0.5) + if min_rerank_score > 0.0: + original_count = len(unique_chunks) + + # Filter chunks with score below threshold + filtered_chunks = [] + for chunk in unique_chunks: + rerank_score = chunk.get( + "rerank_score", 1.0 + ) # Default to 1.0 if no score + if rerank_score >= min_rerank_score: + filtered_chunks.append(chunk) + + unique_chunks = filtered_chunks + filtered_count = original_count - len(unique_chunks) + + if filtered_count > 0: + logger.info( + f"Rerank filtering: {len(unique_chunks)} chunks remained (min rerank score: {min_rerank_score})" + ) + if not unique_chunks: + return [] + + # 3. Apply chunk_top_k limiting if specified + if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0: + if len(unique_chunks) > query_param.chunk_top_k: + unique_chunks = unique_chunks[: query_param.chunk_top_k] + logger.debug( + f"Kept chunk_top-k: {len(unique_chunks)} chunks (deduplicated original: {origin_count})" + ) + + # 4. Token-based final truncation + tokenizer = global_config.get("tokenizer") + if tokenizer and unique_chunks: + # Set default chunk_token_limit if not provided + if chunk_token_limit is None: + # Get default from query_param or global_config + chunk_token_limit = getattr( + query_param, + "max_total_tokens", + global_config.get("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS), + ) + + original_count = len(unique_chunks) + + unique_chunks = truncate_list_by_token_size( + unique_chunks, + key=lambda x: "\n".join( + json.dumps(item, ensure_ascii=False) for item in [x] + ), + max_token_size=chunk_token_limit, + tokenizer=tokenizer, + ) + + logger.debug( + f"Token truncation: {len(unique_chunks)} chunks from {original_count} " + f"(chunk available tokens: {chunk_token_limit}, source: {source_type})" + ) + + # 5. add id field to each chunk + final_chunks = [] + for i, chunk in enumerate(unique_chunks): + chunk_with_id = chunk.copy() + chunk_with_id["id"] = f"DC{i + 1}" + final_chunks.append(chunk_with_id) + + return final_chunks + + +def normalize_source_ids_limit_method(method: str | None) -> str: + """Normalize the source ID limiting strategy and fall back to default when invalid.""" + + if not method: + return DEFAULT_SOURCE_IDS_LIMIT_METHOD + + normalized = method.upper() + if normalized not in VALID_SOURCE_IDS_LIMIT_METHODS: + logger.warning( + "Unknown SOURCE_IDS_LIMIT_METHOD '%s', falling back to %s", + method, + DEFAULT_SOURCE_IDS_LIMIT_METHOD, + ) + return DEFAULT_SOURCE_IDS_LIMIT_METHOD + + return normalized + + +def merge_source_ids( + existing_ids: Iterable[str] | None, new_ids: Iterable[str] | None +) -> list[str]: + """Merge two iterables of source IDs while preserving order and removing duplicates.""" + + merged: list[str] = [] + seen: set[str] = set() + + for sequence in (existing_ids, new_ids): + if not sequence: + continue + for source_id in sequence: + if not source_id: + continue + if source_id not in seen: + seen.add(source_id) + merged.append(source_id) + + return merged + + +def apply_source_ids_limit( + source_ids: Sequence[str], + limit: int, + method: str, + *, + identifier: str | None = None, +) -> list[str]: + """Apply a limit strategy to a sequence of source IDs.""" + + if limit <= 0: + return [] + + source_ids_list = list(source_ids) + if len(source_ids_list) <= limit: + return source_ids_list + + normalized_method = normalize_source_ids_limit_method(method) + + if normalized_method == SOURCE_IDS_LIMIT_METHOD_FIFO: + truncated = source_ids_list[-limit:] + else: # IGNORE_NEW + truncated = source_ids_list[:limit] + + if identifier and len(truncated) < len(source_ids_list): + logger.debug( + "Source_id truncated: %s | %s keeping %s of %s entries", + identifier, + normalized_method, + len(truncated), + len(source_ids_list), + ) + + return truncated + + +def compute_incremental_chunk_ids( + existing_full_chunk_ids: list[str], + old_chunk_ids: list[str], + new_chunk_ids: list[str], +) -> list[str]: + """ + Compute incrementally updated chunk IDs based on changes. + + This function applies delta changes (additions and removals) to an existing + list of chunk IDs while maintaining order and ensuring deduplication. + Delta additions from new_chunk_ids are placed at the end. + + Args: + existing_full_chunk_ids: Complete list of existing chunk IDs from storage + old_chunk_ids: Previous chunk IDs from source_id (chunks being replaced) + new_chunk_ids: New chunk IDs from updated source_id (chunks being added) + + Returns: + Updated list of chunk IDs with deduplication + + Example: + >>> existing = ['chunk-1', 'chunk-2', 'chunk-3'] + >>> old = ['chunk-1', 'chunk-2'] + >>> new = ['chunk-2', 'chunk-4'] + >>> compute_incremental_chunk_ids(existing, old, new) + ['chunk-3', 'chunk-2', 'chunk-4'] + """ + # Calculate changes + chunks_to_remove = set(old_chunk_ids) - set(new_chunk_ids) + chunks_to_add = set(new_chunk_ids) - set(old_chunk_ids) + + # Apply changes to full chunk_ids + # Step 1: Remove chunks that are no longer needed + updated_chunk_ids = [ + cid for cid in existing_full_chunk_ids if cid not in chunks_to_remove + ] + + # Step 2: Add new chunks (preserving order from new_chunk_ids) + # Note: 'cid not in updated_chunk_ids' check ensures deduplication + for cid in new_chunk_ids: + if cid in chunks_to_add and cid not in updated_chunk_ids: + updated_chunk_ids.append(cid) + + return updated_chunk_ids + + +def subtract_source_ids( + source_ids: Iterable[str], + ids_to_remove: Collection[str], +) -> list[str]: + """Remove a collection of IDs from an ordered iterable while preserving order.""" + + removal_set = set(ids_to_remove) + if not removal_set: + return [source_id for source_id in source_ids if source_id] + + return [ + source_id + for source_id in source_ids + if source_id and source_id not in removal_set + ] + + +def make_relation_chunk_key(src: str, tgt: str) -> str: + """Create a deterministic storage key for relation chunk tracking.""" + + return GRAPH_FIELD_SEP.join(sorted((src, tgt))) + + +def parse_relation_chunk_key(key: str) -> tuple[str, str]: + """Parse a relation chunk storage key back into its entity pair.""" + + parts = key.split(GRAPH_FIELD_SEP) + if len(parts) != 2: + raise ValueError(f"Invalid relation chunk key: {key}") + return parts[0], parts[1] + + +def generate_track_id(prefix: str = "upload") -> str: + """Generate a unique tracking ID with timestamp and UUID + + Args: + prefix: Prefix for the track ID (e.g., 'upload', 'insert') + + Returns: + str: Unique tracking ID in format: {prefix}_{timestamp}_{uuid} + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_id = str(uuid.uuid4())[:8] # Use first 8 characters of UUID + return f"{prefix}_{timestamp}_{unique_id}" + + +def get_pinyin_sort_key(text: str) -> str: + """Generate sort key for Chinese pinyin sorting + + This function uses pypinyin for true Chinese pinyin sorting. + If pypinyin is not available, it falls back to simple lowercase string sorting. + + Args: + text: Text to generate sort key for + + Returns: + str: Sort key that can be used for comparison and sorting + """ + if not text: + return "" + + if _PYPINYIN_AVAILABLE: + try: + # Convert Chinese characters to pinyin, keep non-Chinese as-is + pinyin_list = pypinyin.lazy_pinyin(text, style=pypinyin.Style.NORMAL) + return "".join(pinyin_list).lower() + except Exception: + # Silently fall back to simple string sorting on any error + return text.lower() + else: + # pypinyin not available, use simple string sorting + return text.lower() + + +def fix_tuple_delimiter_corruption( + record: str, delimiter_core: str, tuple_delimiter: str +) -> str: + """ + Fix various forms of tuple_delimiter corruption from LLM output. + + This function handles missing or replaced characters around the core delimiter. + It fixes common corruption patterns where the LLM output doesn't match the expected + tuple_delimiter format. + + Args: + record: The text record to fix + delimiter_core: The core delimiter (e.g., "S" from "<|#|>") + tuple_delimiter: The complete tuple delimiter (e.g., "<|#|>") + + Returns: + The corrected record with proper tuple_delimiter format + """ + if not record or not delimiter_core or not tuple_delimiter: + return record + + # Escape the delimiter core for regex use + escaped_delimiter_core = re.escape(delimiter_core) + + # Fix: <|##|> -> <|#|>, <|#||#|> -> <|#|>, <|#|||#|> -> <|#|> + record = re.sub( + rf"<\|{escaped_delimiter_core}\|*?{escaped_delimiter_core}\|>", + tuple_delimiter, + record, + ) + + # Fix: <|\#|> -> <|#|> + record = re.sub( + rf"<\|\\{escaped_delimiter_core}\|>", + tuple_delimiter, + record, + ) + + # Fix: <|> -> <|#|>, <||> -> <|#|> + record = re.sub( + r"<\|+>", + tuple_delimiter, + record, + ) + + # Fix: -> <|#|>, <|#|Y> -> <|#|>, -> <|#|>, <||#||> -> <|#|> (one extra characters outside pipes) + record = re.sub( + rf"<.?\|{escaped_delimiter_core}\|.?>", + tuple_delimiter, + record, + ) + + # Fix: <#>, <#|>, <|#> -> <|#|> (missing one or both pipes) + record = re.sub( + rf"<\|?{escaped_delimiter_core}\|?>", + tuple_delimiter, + record, + ) + + # Fix: -> <|#|>, <|#X> -> <|#|> (one pipe is replaced by other character) + record = re.sub( + rf"<[^|]{escaped_delimiter_core}\|>|<\|{escaped_delimiter_core}[^|]>", + tuple_delimiter, + record, + ) + + # Fix: <|#| -> <|#|>, <|#|| -> <|#|> (missing closing >) + record = re.sub( + rf"<\|{escaped_delimiter_core}\|+(?!>)", + tuple_delimiter, + record, + ) + + # Fix <|#: -> <|#|> (missing closing >) + record = re.sub( + rf"<\|{escaped_delimiter_core}:(?!>)", + tuple_delimiter, + record, + ) + + # Fix: <||#> -> <|#|> (double pipe at start, missing pipe at end) + record = re.sub( + rf"<\|+{escaped_delimiter_core}>", + tuple_delimiter, + record, + ) + + # Fix: <|| -> <|#|> + record = re.sub( + r"<\|\|(?!>)", + tuple_delimiter, + record, + ) + + # Fix: |#|> -> <|#|> (missing opening <) + record = re.sub( + rf"(?", + tuple_delimiter, + record, + ) + + # Fix: <|#|>| -> <|#|> ( this is a fix for: <|#|| -> <|#|> ) + record = re.sub( + rf"<\|{escaped_delimiter_core}\|>\|", + tuple_delimiter, + record, + ) + + # Fix: ||#|| -> <|#|> (double pipes on both sides without angle brackets) + record = re.sub( + rf"\|\|{escaped_delimiter_core}\|\|", + tuple_delimiter, + record, + ) + + return record + + +def create_prefixed_exception(original_exception: Exception, prefix: str) -> Exception: + """ + Safely create a prefixed exception that adapts to all error types. + + Args: + original_exception: The original exception. + prefix: The prefix to add. + + Returns: + A new exception with the prefix, maintaining the original exception type if possible. + """ + try: + # Method 1: Try to reconstruct using original arguments. + if hasattr(original_exception, "args") and original_exception.args: + args = list(original_exception.args) + # Find the first string argument and prefix it. This is safer for + # exceptions like OSError where the first arg is an integer (errno). + found_str = False + for i, arg in enumerate(args): + if isinstance(arg, str): + args[i] = f"{prefix}: {arg}" + found_str = True + break + + # If no string argument is found, prefix the first argument's string representation. + if not found_str: + args[0] = f"{prefix}: {args[0]}" + + return type(original_exception)(*args) + else: + # Method 2: If no args, try single parameter construction. + return type(original_exception)(f"{prefix}: {str(original_exception)}") + except (TypeError, ValueError, AttributeError) as construct_error: + # Method 3: If reconstruction fails, wrap it in a RuntimeError. + # This is the safest fallback, as attempting to create the same type + # with a single string can fail if the constructor requires multiple arguments. + return RuntimeError( + f"{prefix}: {type(original_exception).__name__}: {str(original_exception)} " + f"(Original exception could not be reconstructed: {construct_error})" + ) + + +def convert_to_user_format( + entities_context: list[dict], + relations_context: list[dict], + chunks: list[dict], + references: list[dict], + query_mode: str, + entity_id_to_original: dict = None, + relation_id_to_original: dict = None, +) -> dict[str, Any]: + """Convert internal data format to user-friendly format using original database data""" + + # Convert entities format using original data when available + formatted_entities = [] + for entity in entities_context: + entity_name = entity.get("entity", "") + + # Try to get original data first + original_entity = None + if entity_id_to_original and entity_name in entity_id_to_original: + original_entity = entity_id_to_original[entity_name] + + if original_entity: + # Use original database data + formatted_entities.append( + { + "entity_name": original_entity.get("entity_name", entity_name), + "entity_type": original_entity.get("entity_type", "UNKNOWN"), + "description": original_entity.get("description", ""), + "source_id": original_entity.get("source_id", ""), + "file_path": original_entity.get("file_path", "unknown_source"), + "created_at": original_entity.get("created_at", ""), + } + ) + else: + # Fallback to LLM context data (for backward compatibility) + formatted_entities.append( + { + "entity_name": entity_name, + "entity_type": entity.get("type", "UNKNOWN"), + "description": entity.get("description", ""), + "source_id": entity.get("source_id", ""), + "file_path": entity.get("file_path", "unknown_source"), + "created_at": entity.get("created_at", ""), + } + ) + + # Convert relationships format using original data when available + formatted_relationships = [] + for relation in relations_context: + entity1 = relation.get("entity1", "") + entity2 = relation.get("entity2", "") + relation_key = (entity1, entity2) + + # Try to get original data first + original_relation = None + if relation_id_to_original and relation_key in relation_id_to_original: + original_relation = relation_id_to_original[relation_key] + + if original_relation: + # Use original database data + formatted_relationships.append( + { + "src_id": original_relation.get("src_id", entity1), + "tgt_id": original_relation.get("tgt_id", entity2), + "description": original_relation.get("description", ""), + "keywords": original_relation.get("keywords", ""), + "weight": original_relation.get("weight", 1.0), + "source_id": original_relation.get("source_id", ""), + "file_path": original_relation.get("file_path", "unknown_source"), + "created_at": original_relation.get("created_at", ""), + } + ) + else: + # Fallback to LLM context data (for backward compatibility) + formatted_relationships.append( + { + "src_id": entity1, + "tgt_id": entity2, + "description": relation.get("description", ""), + "keywords": relation.get("keywords", ""), + "weight": relation.get("weight", 1.0), + "source_id": relation.get("source_id", ""), + "file_path": relation.get("file_path", "unknown_source"), + "created_at": relation.get("created_at", ""), + } + ) + + # Convert chunks format (chunks already contain complete data) + formatted_chunks = [] + for i, chunk in enumerate(chunks): + chunk_data = { + "reference_id": chunk.get("reference_id", ""), + "content": chunk.get("content", ""), + "file_path": chunk.get("file_path", "unknown_source"), + "chunk_id": chunk.get("chunk_id", ""), + } + formatted_chunks.append(chunk_data) + + logger.debug( + f"[convert_to_user_format] Formatted {len(formatted_chunks)}/{len(chunks)} chunks" + ) + + # Build basic metadata (metadata details will be added by calling functions) + metadata = { + "query_mode": query_mode, + "keywords": { + "high_level": [], + "low_level": [], + }, # Placeholder, will be set by calling functions + } + + return { + "status": "success", + "message": "Query processed successfully", + "data": { + "entities": formatted_entities, + "relationships": formatted_relationships, + "chunks": formatted_chunks, + "references": references, + }, + "metadata": metadata, + } + + +def generate_reference_list_from_chunks( + chunks: list[dict], +) -> tuple[list[dict], list[dict]]: + """ + Generate reference list from chunks, prioritizing by occurrence frequency. + + This function extracts file_paths from chunks, counts their occurrences, + sorts by frequency and first appearance order, creates reference_id mappings, + and builds a reference_list structure. + + Args: + chunks: List of chunk dictionaries with file_path information + + Returns: + tuple: (reference_list, updated_chunks_with_reference_ids) + - reference_list: List of dicts with reference_id and file_path + - updated_chunks_with_reference_ids: Original chunks with reference_id field added + """ + if not chunks: + return [], [] + + # 1. Extract all valid file_paths and count their occurrences + file_path_counts = {} + for chunk in chunks: + file_path = chunk.get("file_path", "") + if file_path and file_path != "unknown_source": + file_path_counts[file_path] = file_path_counts.get(file_path, 0) + 1 + + # 2. Sort file paths by frequency (descending), then by first appearance order + # Create a list of (file_path, count, first_index) tuples + file_path_with_indices = [] + seen_paths = set() + for i, chunk in enumerate(chunks): + file_path = chunk.get("file_path", "") + if file_path and file_path != "unknown_source" and file_path not in seen_paths: + file_path_with_indices.append((file_path, file_path_counts[file_path], i)) + seen_paths.add(file_path) + + # Sort by count (descending), then by first appearance index (ascending) + sorted_file_paths = sorted(file_path_with_indices, key=lambda x: (-x[1], x[2])) + unique_file_paths = [item[0] for item in sorted_file_paths] + + # 3. Create mapping from file_path to reference_id (prioritized by frequency) + file_path_to_ref_id = {} + for i, file_path in enumerate(unique_file_paths): + file_path_to_ref_id[file_path] = str(i + 1) + + # 4. Add reference_id field to each chunk + updated_chunks = [] + for chunk in chunks: + chunk_copy = chunk.copy() + file_path = chunk_copy.get("file_path", "") + if file_path and file_path != "unknown_source": + chunk_copy["reference_id"] = file_path_to_ref_id[file_path] + else: + chunk_copy["reference_id"] = "" + updated_chunks.append(chunk_copy) + + # 5. Build reference_list + reference_list = [] + for i, file_path in enumerate(unique_file_paths): + reference_list.append({"reference_id": str(i + 1), "file_path": file_path}) + + return reference_list, updated_chunks diff --git a/lightrag/utils.py b/lightrag/utils.py index fc8bc401c9..f798a4fdc5 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -410,120 +410,88 @@ class TaskState: @dataclass class EmbeddingFunc: - """Embedding function wrapper with dimension validation - - This class wraps an embedding function to ensure that the output embeddings have the correct dimension. - If wrapped multiple times, the inner wrappers will be automatically unwrapped to prevent - configuration conflicts where inner wrapper settings would override outer wrapper settings. - - Using functools.partial for parameter binding: - A common pattern is to use functools.partial to pre-bind model and host parameters - to an embedding function. When the base embedding function is already decorated with - @wrap_embedding_func_with_attrs (e.g., ollama_embed), use `.func` to access the - original unwrapped function to avoid double wrapping: - - Example: - from functools import partial - - # ❌ Wrong - causes double wrapping (inner EmbeddingFunc still executes) - func=partial(ollama_embed, embed_model="bge-m3:latest", host="http://localhost:11434") - - # ✅ Correct - access the unwrapped function via .func - func=partial(ollama_embed.func, embed_model="bge-m3:latest", host="http://localhost:11434") - - Args: - embedding_dim: Expected dimension of the embeddings(For dimension checking and workspace data isolation in vector DB) - func: The actual embedding function to wrap - max_token_size: Enable embedding token limit checking for description summarization(Set embedding_token_limit in LightRAG) - send_dimensions: Whether to inject embedding_dim argument to underlying function - model_name: Model name for implementing workspace data isolation in vector DB - """ + """Embedding function wrapper with dimension validation + allow_extra_vectors flag (fixes #2549)""" embedding_dim: int func: callable max_token_size: int | None = None send_dimensions: bool = False - model_name: str | None = ( - None # Model name for implementing workspace data isolation in vector DB - ) + model_name: str | None = None + allow_extra_vectors: bool = False # ← our fix — default is safe fail-fast def __post_init__(self): - """Unwrap nested EmbeddingFunc to prevent double wrapping issues. - - When an EmbeddingFunc wraps another EmbeddingFunc, the inner wrapper's - __call__ preprocessing would override the outer wrapper's settings. - This method detects and unwraps nested EmbeddingFunc instances to ensure - that only the outermost wrapper's configuration is applied. - """ - # Check if func is already an EmbeddingFunc instance and unwrap it - max_unwrap_depth = 3 # Safety limit to prevent infinite loops + """Unwrap nested EmbeddingFunc to prevent double wrapping issues.""" + max_unwrap_depth = 3 unwrap_count = 0 while isinstance(self.func, EmbeddingFunc): unwrap_count += 1 if unwrap_count > max_unwrap_depth: raise ValueError( - f"EmbeddingFunc unwrap depth exceeded {max_unwrap_depth}. " - "Possible circular reference detected." + "EmbeddingFunc unwrap depth exceeded — possible circular reference." ) - # Unwrap to get the original function self.func = self.func.func if unwrap_count > 0: logger.warning( - f"Detected nested EmbeddingFunc wrapping (depth: {unwrap_count}), " - "auto-unwrapped to prevent configuration conflicts. " - "Consider using .func to access the unwrapped function directly." + f"Detected nested EmbeddingFunc (depth: {unwrap_count}), auto-unwrapped." ) async def __call__(self, *args, **kwargs) -> np.ndarray: - # Only inject embedding_dim when send_dimensions is True + # === Original upstream logic (unchanged) === if self.send_dimensions: - # Check if user provided embedding_dim parameter - if "embedding_dim" in kwargs: - user_provided_dim = kwargs["embedding_dim"] - # If user's value differs from class attribute, output warning - if ( - user_provided_dim is not None - and user_provided_dim != self.embedding_dim - ): - logger.warning( - f"Ignoring user-provided embedding_dim={user_provided_dim}, " - f"using declared embedding_dim={self.embedding_dim} from decorator" - ) - - # Inject embedding_dim from decorator + if ( + "embedding_dim" in kwargs + and kwargs["embedding_dim"] is not None + and kwargs["embedding_dim"] != self.embedding_dim + ): + logger.warning( + f"Ignoring user-provided embedding_dim, using {self.embedding_dim}" + ) kwargs["embedding_dim"] = self.embedding_dim - # Check if underlying function supports max_token_size and inject if not provided if self.max_token_size is not None and "max_token_size" not in kwargs: sig = inspect.signature(self.func) if "max_token_size" in sig.parameters: kwargs["max_token_size"] = self.max_token_size - # Call the actual embedding function + # Call the real function result = await self.func(*args, **kwargs) - # Validate embedding dimensions using total element count - total_elements = result.size # Total number of elements in the numpy array + # === Dimension validation (unchanged) === + total_elements = result.size expected_dim = self.embedding_dim - - # Check if total elements can be evenly divided by embedding_dim if total_elements % expected_dim != 0: raise ValueError( - f"Embedding dimension mismatch detected: " - f"total elements ({total_elements}) cannot be evenly divided by " - f"expected dimension ({expected_dim}). " + f"Embedding dimension mismatch: got {total_elements}, expected multiple of {expected_dim}" ) - # Optional: Verify vector count matches input text count + # === Our approved fix for MinerU/RAGAnything === actual_vectors = total_elements // expected_dim if args and isinstance(args[0], (list, tuple)): expected_vectors = len(args[0]) if actual_vectors != expected_vectors: - raise ValueError( - f"Vector count mismatch: " - f"expected {expected_vectors} vectors but got {actual_vectors} vectors (from embedding result)." - ) + provider = self.model_name or "unknown" + if self.allow_extra_vectors: + logger.warning( + f"Vector count mismatch (provider={provider}): " + f"expected {expected_vectors} but got {actual_vectors}. " + f"Adjusting... (allow_extra_vectors=True — normal for MinerU)." + ) + # Over = slice | Under = pad with last vector (safe fallback) + reshaped = result.reshape(-1, expected_dim) + if actual_vectors > expected_vectors: + result = reshaped[:expected_vectors] + else: + pad = np.tile( + reshaped[-1:], (expected_vectors - actual_vectors, 1) + ) + result = np.vstack([reshaped, pad]) + else: + raise ValueError( + f"Vector count mismatch (provider={provider}): " + f"expected {expected_vectors} vectors but got {actual_vectors}. " + f"Set allow_extra_vectors=True on EmbeddingFunc if using MinerU/RAGAnything." + ) return result diff --git a/lightrag_webui/src/App.tsx b/lightrag_webui/src/App.tsx index b8ae023d7a..54075388c6 100644 --- a/lightrag_webui/src/App.tsx +++ b/lightrag_webui/src/App.tsx @@ -15,6 +15,7 @@ import GraphViewer from '@/features/GraphViewer' import DocumentManager from '@/features/DocumentManager' import RetrievalTesting from '@/features/RetrievalTesting' import ApiSite from '@/features/ApiSite' +import SanitizeData from '@/features/SanitizeData' import { Tabs, TabsContent } from '@/components/ui/Tabs' @@ -207,6 +208,9 @@ function App() { + + + diff --git a/lightrag_webui/src/App_original.tsx b/lightrag_webui/src/App_original.tsx new file mode 100644 index 0000000000..b8ae023d7a --- /dev/null +++ b/lightrag_webui/src/App_original.tsx @@ -0,0 +1,230 @@ +import { useState, useCallback, useEffect, useRef } from 'react' +import ThemeProvider from '@/components/ThemeProvider' +import TabVisibilityProvider from '@/contexts/TabVisibilityProvider' +import ApiKeyAlert from '@/components/ApiKeyAlert' +import StatusIndicator from '@/components/status/StatusIndicator' +import { SiteInfo, webuiPrefix } from '@/lib/constants' +import { useBackendState, useAuthStore } from '@/stores/state' +import { useSettingsStore } from '@/stores/settings' +import { getAuthStatus } from '@/api/lightrag' +import SiteHeader from '@/features/SiteHeader' +import { InvalidApiKeyError, RequireApiKeError } from '@/api/lightrag' +import { ZapIcon } from 'lucide-react' + +import GraphViewer from '@/features/GraphViewer' +import DocumentManager from '@/features/DocumentManager' +import RetrievalTesting from '@/features/RetrievalTesting' +import ApiSite from '@/features/ApiSite' + +import { Tabs, TabsContent } from '@/components/ui/Tabs' + +function App() { + const message = useBackendState.use.message() + const enableHealthCheck = useSettingsStore.use.enableHealthCheck() + const currentTab = useSettingsStore.use.currentTab() + const [apiKeyAlertOpen, setApiKeyAlertOpen] = useState(false) + const [initializing, setInitializing] = useState(true) // Add initializing state + const versionCheckRef = useRef(false); // Prevent duplicate calls in Vite dev mode + const healthCheckInitializedRef = useRef(false); // Prevent duplicate health checks in Vite dev mode + + const handleApiKeyAlertOpenChange = useCallback((open: boolean) => { + setApiKeyAlertOpen(open) + if (!open) { + useBackendState.getState().clear() + } + }, []) + + // Track component mount status with useRef + const isMountedRef = useRef(true); + + // Set up mount/unmount status tracking + useEffect(() => { + isMountedRef.current = true; + + // Handle page reload/unload + const handleBeforeUnload = () => { + isMountedRef.current = false; + }; + + window.addEventListener('beforeunload', handleBeforeUnload); + + return () => { + isMountedRef.current = false; + window.removeEventListener('beforeunload', handleBeforeUnload); + }; + }, []); + + // Health check - can be disabled + useEffect(() => { + // Health check function + const performHealthCheck = async () => { + try { + // Only perform health check if component is still mounted + if (isMountedRef.current) { + await useBackendState.getState().check(); + } + } catch (error) { + console.error('Health check error:', error); + } + }; + + // Set health check function in the store + useBackendState.getState().setHealthCheckFunction(performHealthCheck); + + if (!enableHealthCheck || apiKeyAlertOpen) { + useBackendState.getState().clearHealthCheckTimer(); + return; + } + + // On first mount or when enableHealthCheck becomes true and apiKeyAlertOpen is false, + // perform an immediate health check and start the timer + if (!healthCheckInitializedRef.current) { + healthCheckInitializedRef.current = true; + } + + // Start/reset the health check timer using the store + useBackendState.getState().resetHealthCheckTimer(); + + // Component unmount cleanup + return () => { + useBackendState.getState().clearHealthCheckTimer(); + }; + }, [enableHealthCheck, apiKeyAlertOpen]); + + // Version check - independent and executed only once + useEffect(() => { + const checkVersion = async () => { + // Prevent duplicate calls in Vite dev mode + if (versionCheckRef.current) return; + versionCheckRef.current = true; + + // Check if version info was already obtained in login page + const versionCheckedFromLogin = sessionStorage.getItem('VERSION_CHECKED_FROM_LOGIN') === 'true'; + if (versionCheckedFromLogin) { + setInitializing(false); // Skip initialization if already checked + return; + } + + try { + setInitializing(true); // Start initialization + + // Get version info + const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); + const status = await getAuthStatus(); + + // If auth is not configured and a new token is returned, use the new token + if (!status.auth_configured && status.access_token) { + useAuthStore.getState().login( + status.access_token, // Use the new token + true, // Guest mode + status.core_version, + status.api_version, + status.webui_title || null, + status.webui_description || null + ); + } else if (token && (status.core_version || status.api_version || status.webui_title || status.webui_description)) { + // Otherwise use the old token (if it exists) + const isGuestMode = status.auth_mode === 'disabled' || useAuthStore.getState().isGuestMode; + useAuthStore.getState().login( + token, + isGuestMode, + status.core_version, + status.api_version, + status.webui_title || null, + status.webui_description || null + ); + } + + // Set flag to indicate version info has been checked + sessionStorage.setItem('VERSION_CHECKED_FROM_LOGIN', 'true'); + } catch (error) { + console.error('Failed to get version info:', error); + } finally { + // Ensure initializing is set to false even if there's an error + setInitializing(false); + } + }; + + // Execute version check + checkVersion(); + }, []); // Empty dependency array ensures it only runs once on mount + + const handleTabChange = useCallback( + (tab: string) => useSettingsStore.getState().setCurrentTab(tab as any), + [] + ) + + useEffect(() => { + if (message) { + if (message.includes(InvalidApiKeyError) || message.includes(RequireApiKeError)) { + setApiKeyAlertOpen(true) + } + } + }, [message]) + + return ( + + + {initializing ? ( + // Loading state while initializing with simplified header +

+ {/* Simplified header during initialization - matches SiteHeader structure */} +
+ + + {/* Empty middle section to maintain layout */} +
+
+ + {/* Empty right section to maintain layout */} + +
+ + {/* Loading indicator in content area */} +
+
+
+

Initializing...

+
+
+
+ ) : ( + // Main content after initialization +
+ + +
+ + + + + + + + + + + + +
+
+ {enableHealthCheck && } + +
+ )} + + + ) +} + +export default App diff --git a/lightrag_webui/src/components/retrieval/ChatMessage.tsx b/lightrag_webui/src/components/retrieval/ChatMessage.tsx index 0abb3b95fc..87eb43131b 100644 --- a/lightrag_webui/src/components/retrieval/ChatMessage.tsx +++ b/lightrag_webui/src/components/retrieval/ChatMessage.tsx @@ -222,13 +222,15 @@ export const ChatMessage = ({ {/* Main content display */} {finalDisplayContent && (
-
.base]:overflow-x-auto [&_sup]:text-[0.75em] [&_sup]:align-[0.1em] [&_sup]:leading-[0] [&_sub]:text-[0.75em] [&_sub]:align-[-0.2em] [&_sub]:leading-[0] [&_mark]:bg-yellow-200 [&_mark]:dark:bg-yellow-800 [&_u]:underline [&_del]:line-through [&_ins]:underline [&_ins]:decoration-green-500 [&_.footnotes]:mt-8 [&_.footnotes]:pt-4 [&_.footnotes]:border-t [&_.footnotes_ol]:text-sm [&_.footnotes_li]:my-1 ${ - message.role === 'user' ? 'text-primary-foreground' : 'text-foreground' - } ${ - message.role === 'user' - ? '[&_.footnotes]:border-primary-foreground/30 [&_a[href^="#fn"]]:text-primary-foreground [&_a[href^="#fn"]]:no-underline [&_a[href^="#fn"]]:hover:underline [&_a[href^="#fnref"]]:text-primary-foreground [&_a[href^="#fnref"]]:no-underline [&_a[href^="#fnref"]]:hover:underline' - : '[&_.footnotes]:border-border [&_a[href^="#fn"]]:text-primary [&_a[href^="#fn"]]:no-underline [&_a[href^="#fn"]]:hover:underline [&_a[href^="#fnref"]]:text-primary [&_a[href^="#fnref"]]:no-underline [&_a[href^="#fnref"]]:hover:underline' - }`}> +
.base]:overflow-x-auto [&_sup]:text-[0.75em] [&_sup]:align-[0.1em] [&_sup]:leading-[0] [&_sub]:text-[0.75em] [&_sub]:align-[-0.2em] [&_sub]:leading-[0] [&_mark]:bg-yellow-200 [&_mark]:dark:bg-yellow-800 [&_u]:underline [&_del]:line-through [&_ins]:underline [&_ins]:decoration-green-500 [&_.footnotes]:mt-8 [&_.footnotes]:pt-4 [&_.footnotes]:border-t [&_.footnotes_ol]:text-sm [&_.footnotes_li]:my-1 ${ + message.role === 'user' ? 'text-primary-foreground' : 'text-foreground' + } ${ + message.role === 'user' + ? '[&_.footnotes]:border-primary-foreground/30 [&_a[href^="#fn"]]:text-primary-foreground [&_a[href^="#fn"]]:no-underline [&_a[href^="#fn"]]:hover:underline [&_a[href^="#fnref"]]:text-primary-foreground [&_a[href^="#fnref"]]:no-underline [&_a[href^="#fnref"]]:hover:underline' + : '[&_.footnotes]:border-border [&_a[href^="#fn"]]:text-primary [&_a[href^="#fn"]]:no-underline [&_a[href^="#fn"]]:hover:underline [&_a[href^="#fnref"]]:text-primary [&_a[href^="#fnref"]]:no-underline [&_a[href^="#fnref"]]:hover:underline' + }`} + > { - // Only show detailed errors in development environment if (process.env.NODE_ENV === 'development') { console.warn('KaTeX rendering error in main content:', error, 'for LaTeX:', latex); } diff --git a/lightrag_webui/src/components/retrieval/ChatMessage_original.tsx b/lightrag_webui/src/components/retrieval/ChatMessage_original.tsx new file mode 100644 index 0000000000..e490941ca8 --- /dev/null +++ b/lightrag_webui/src/components/retrieval/ChatMessage_original.tsx @@ -0,0 +1,508 @@ +import { ReactNode, useEffect, useMemo, useRef, memo, useState } from 'react' // Import useMemo +import { Message } from '@/api/lightrag' +import useTheme from '@/hooks/useTheme' +import { cn } from '@/lib/utils' + +import ReactMarkdown from 'react-markdown' +import remarkGfm from 'remark-gfm' +import rehypeReact from 'rehype-react' +import rehypeRaw from 'rehype-raw' +import remarkMath from 'remark-math' +import mermaid from 'mermaid' +import { remarkFootnotes } from '@/utils/remarkFootnotes' + + +import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter' +import { oneLight, oneDark } from 'react-syntax-highlighter/dist/cjs/styles/prism' + +import { LoaderIcon, ChevronDownIcon } from 'lucide-react' +import { useTranslation } from 'react-i18next' + +// KaTeX configuration options interface +interface KaTeXOptions { + errorColor?: string; + throwOnError?: boolean; + displayMode?: boolean; + strict?: boolean; + trust?: boolean; + errorCallback?: (error: string, latex: string) => void; +} + +export type MessageWithError = Message & { + id: string // Unique identifier for stable React keys + isError?: boolean + isThinking?: boolean // Flag to indicate if the message is in a "thinking" state + /** + * Indicates if the mermaid diagram in this message has been rendered. + * Used to persist the rendering state across updates and prevent flickering. + */ + mermaidRendered?: boolean + /** + * Indicates if the LaTeX formulas in this message are complete and ready for rendering. + * Used to prevent red error text during streaming of incomplete LaTeX formulas. + */ + latexRendered?: boolean +} + +// Restore original component definition and export +export const ChatMessage = ({ + message, + isTabActive = true +}: { + message: MessageWithError + isTabActive?: boolean +}) => { + const { t } = useTranslation() + const { theme } = useTheme() + const [katexPlugin, setKatexPlugin] = useState<((options?: KaTeXOptions) => any) | null>(null) + const [isThinkingExpanded, setIsThinkingExpanded] = useState(false) + + // Directly use props passed from the parent. + const { thinkingContent, displayContent, thinkingTime, isThinking } = message + + // Reset expansion state when new thinking starts + useEffect(() => { + if (isThinking) { + // When thinking starts, always reset to collapsed state + setIsThinkingExpanded(false) + } + }, [isThinking, message.id]) + + // The content to display is now non-ambiguous. + const finalThinkingContent = thinkingContent + // For user messages, displayContent will be undefined, so we fall back to content. + // For assistant messages, we prefer displayContent but fallback to content for backward compatibility + const finalDisplayContent = message.role === 'user' + ? message.content + : (displayContent !== undefined ? displayContent : (message.content || '')) + + // Load KaTeX rehype plugin dynamically + // Note: KaTeX extensions (mhchem, copy-tex) are imported statically in main.tsx + useEffect(() => { + const loadKaTeX = async () => { + try { + const { default: rehypeKatex } = await import('rehype-katex'); + setKatexPlugin(() => rehypeKatex); + } catch (error) { + console.error('Failed to load KaTeX plugin:', error); + setKatexPlugin(null); + } + }; + + loadKaTeX(); + }, []); + + const mainMarkdownComponents = useMemo(() => ({ + code: (props: any) => { + const { inline, className, children, ...restProps } = props; + const match = /language-(\w+)/.exec(className || ''); + const language = match ? match[1] : undefined; + + // Handle math blocks ($$...$$) - provide better container and styling + if (language === 'math' && !inline) { + return ( +
+
{children}
+
+ ); + } + + // Handle inline math ($...$) - ensure proper inline display + if (language === 'math' && inline) { + return ( + + {children} + + ); + } + + // Handle all other code (inline and block) + return ( + + {children} + + ); + }, + p: ({ children }: { children?: ReactNode }) =>
{children}
, + h1: ({ children }: { children?: ReactNode }) =>

{children}

, + h2: ({ children }: { children?: ReactNode }) =>

{children}

, + h3: ({ children }: { children?: ReactNode }) =>

{children}

, + h4: ({ children }: { children?: ReactNode }) =>

{children}

, + ul: ({ children }: { children?: ReactNode }) =>
    {children}
, + ol: ({ children }: { children?: ReactNode }) =>
    {children}
, + li: ({ children }: { children?: ReactNode }) =>
  • {children}
  • + }), [message.mermaidRendered, message.role]); + + const thinkingMarkdownComponents = useMemo(() => ({ + code: (props: any) => () + }), [message.mermaidRendered, message.role]); + + return ( +
    + {/* Thinking process display - only for assistant messages */} + {/* Always render to prevent layout shift when switching tabs */} + {message.role === 'assistant' && (isThinking || thinkingTime !== null) && ( +
    +
    { + // Allow expansion when there's thinking content, even during thinking process + if (finalThinkingContent && finalThinkingContent.trim() !== '') { + setIsThinkingExpanded(!isThinkingExpanded) + } + }} + > + {isThinking ? ( + <> + {/* Only show spinner animation in active tab to save resources */} + {isTabActive && } + {t('retrievePanel.chatMessage.thinking')} + + ) : ( + typeof thinkingTime === 'number' && {t('retrievePanel.chatMessage.thinkingTime', { time: thinkingTime })} + )} + {/* Show chevron when there's thinking content, even during thinking process */} + {finalThinkingContent && finalThinkingContent.trim() !== '' && } +
    + {/* Show thinking content when expanded and content exists, even during thinking process */} + {isThinkingExpanded && finalThinkingContent && finalThinkingContent.trim() !== '' && ( +
    + {isThinking && ( +
    + {t('retrievePanel.chatMessage.thinkingInProgress', 'Thinking in progress...')} +
    + )} + { + // Only show detailed errors in development environment + if (process.env.NODE_ENV === 'development') { + console.warn('KaTeX rendering error in thinking content:', error, 'for LaTeX:', latex); + } + } + }] as any] : []), + rehypeReact + ]} + skipHtml={false} + components={thinkingMarkdownComponents} + > + {finalThinkingContent} + +
    + )} +
    + )} + {/* Main content display */} + {finalDisplayContent && ( +
    + .base]:overflow-x-auto [&_sup]:text-[0.75em] [&_sup]:align-[0.1em] [&_sup]:leading-[0] [&_sub]:text-[0.75em] [&_sub]:align-[-0.2em] [&_sub]:leading-[0] [&_mark]:bg-yellow-200 [&_mark]:dark:bg-yellow-800 [&_u]:underline [&_del]:line-through [&_ins]:underline [&_ins]:decoration-green-500 [&_.footnotes]:mt-8 [&_.footnotes]:pt-4 [&_.footnotes]:border-t [&_.footnotes_ol]:text-sm [&_.footnotes_li]:my-1 ${ + message.role === 'user' ? 'text-primary-foreground' : 'text-foreground' + } ${ + message.role === 'user' + ? '[&_.footnotes]:border-primary-foreground/30 [&_a[href^="#fn"]]:text-primary-foreground [&_a[href^="#fn"]]:no-underline [&_a[href^="#fn"]]:hover:underline [&_a[href^="#fnref"]]:text-primary-foreground [&_a[href^="#fnref"]]:no-underline [&_a[href^="#fnref"]]:hover:underline' + : '[&_.footnotes]:border-border [&_a[href^="#fn"]]:text-primary [&_a[href^="#fn"]]:no-underline [&_a[href^="#fn"]]:hover:underline [&_a[href^="#fnref"]]:text-primary [&_a[href^="#fnref"]]:no-underline [&_a[href^="#fnref"]]:hover:underline' + }`} + remarkPlugins={[remarkGfm, remarkFootnotes, remarkMath]} + rehypePlugins={[ + rehypeRaw, + ...((katexPlugin && (message.latexRendered ?? true)) ? [[ + katexPlugin, + { + errorColor: theme === 'dark' ? '#ef4444' : '#dc2626', + throwOnError: false, + displayMode: false, + strict: false, + trust: true, + // Add silent error handling to avoid console noise + errorCallback: (error: string, latex: string) => { + // Only show detailed errors in development environment + if (process.env.NODE_ENV === 'development') { + console.warn('KaTeX rendering error in main content:', error, 'for LaTeX:', latex); + } + } + } + ] as any] : []), + rehypeReact + ]} + skipHtml={false} + components={mainMarkdownComponents} + > + {finalDisplayContent} + +
    + )} + {/* Loading indicator - only show in active tab */} + {isTabActive && (() => { + // More comprehensive loading state check + const hasVisibleContent = finalDisplayContent && finalDisplayContent.trim() !== ''; + const isLoadingState = !hasVisibleContent && !isThinking && !thinkingTime; + return isLoadingState && + })()} +
    + ) +} + +// Remove the incorrect memo export line + +interface CodeHighlightProps { + inline?: boolean + className?: string + children?: ReactNode + renderAsDiagram?: boolean // Flag to indicate if rendering as diagram should be attempted + messageRole?: 'user' | 'assistant' // Message role for context-aware styling +} + + + +// Check if it is a large JSON +const isLargeJson = (language: string | undefined, content: string | undefined): boolean => { + if (!content || language !== 'json') return false; + return content.length > 5000; // JSON larger than 5KB is considered large JSON +}; + +// Memoize the CodeHighlight component +const CodeHighlight = memo(({ inline, className, children, renderAsDiagram = false, messageRole, ...props }: CodeHighlightProps) => { + const { theme } = useTheme(); + const [hasRendered, setHasRendered] = useState(false); // State to track successful render + const match = className?.match(/language-(\w+)/); + const language = match ? match[1] : undefined; + const mermaidRef = useRef(null); + const debounceTimerRef = useRef | null>(null); // Use ReturnType for better typing + + // Get the content string, check if it is a large JSON + const contentStr = String(children || '').replace(/\n$/, ''); + const isLargeJsonBlock = isLargeJson(language, contentStr); + + // Handle Mermaid rendering with debounce + useEffect(() => { + // Effect should run when renderAsDiagram becomes true or hasRendered changes. + // The actual rendering logic inside checks language and hasRendered state. + if (renderAsDiagram && !hasRendered && language === 'mermaid' && mermaidRef.current) { + const container = mermaidRef.current; // Capture ref value + + // Clear previous timer if dependencies change before timeout (e.g., renderAsDiagram flips quickly) + if (debounceTimerRef.current) { + clearTimeout(debounceTimerRef.current); + } + + debounceTimerRef.current = setTimeout(() => { + if (!container) return; // Container might have unmounted + + // Double check hasRendered state inside timeout, in case it changed rapidly + if (hasRendered) return; + + try { + // Initialize mermaid config + mermaid.initialize({ + startOnLoad: false, + theme: theme === 'dark' ? 'dark' : 'default', + securityLevel: 'loose', + suppressErrorRendering: true, + }); + + // Show loading indicator + container.innerHTML = '
    '; + + // Preprocess mermaid content + const rawContent = String(children).replace(/\n$/, '').trim(); + + // Heuristic check for potentially complete graph definition + const looksPotentiallyComplete = rawContent.length > 10 && ( + rawContent.startsWith('graph') || + rawContent.startsWith('sequenceDiagram') || + rawContent.startsWith('classDiagram') || + rawContent.startsWith('stateDiagram') || + rawContent.startsWith('gantt') || + rawContent.startsWith('pie') || + rawContent.startsWith('flowchart') || + rawContent.startsWith('erDiagram') + ); + + if (!looksPotentiallyComplete) { + console.log('Mermaid content might be incomplete, skipping render attempt:', rawContent); + // Optionally keep loading indicator or show a message + // container.innerHTML = '

    Waiting for complete diagram...

    '; + return; + } + + const processedContent = rawContent + .split('\n') + .map(line => { + const trimmedLine = line.trim(); + if (trimmedLine.startsWith('subgraph')) { + const parts = trimmedLine.split(' '); + if (parts.length > 1) { + const title = parts.slice(1).join(' ').replace(/["']/g, ''); + return `subgraph "${title}"`; + } + } + return trimmedLine; + }) + .filter(line => !line.trim().startsWith('linkStyle')) + .join('\n'); + + const mermaidId = `mermaid-${Date.now()}`; + mermaid.render(mermaidId, processedContent) + .then(({ svg, bindFunctions }) => { + // Check ref and hasRendered state again inside async callback + if (mermaidRef.current === container && !hasRendered) { + container.innerHTML = svg; + setHasRendered(true); // Mark as rendered successfully + if (bindFunctions) { + try { + bindFunctions(container); + } catch (bindError) { + console.error('Mermaid bindFunctions error:', bindError); + container.innerHTML += '

    Diagram interactions might be limited.

    '; + } + } + } else if (mermaidRef.current !== container) { + console.log('Mermaid container changed before rendering completed.'); + } + }) + .catch(error => { + console.error('Mermaid rendering promise error (debounced):', error); + console.error('Failed content (debounced):', processedContent); + if (mermaidRef.current === container) { + const errorMessage = error instanceof Error ? error.message : String(error); + const errorPre = document.createElement('pre'); + errorPre.className = 'text-red-500 text-xs whitespace-pre-wrap break-words'; + errorPre.textContent = `Mermaid diagram error: ${errorMessage}\n\nContent:\n${processedContent}`; + container.innerHTML = ''; + container.appendChild(errorPre); + } + }); + + } catch (error) { + console.error('Mermaid synchronous error (debounced):', error); + console.error('Failed content (debounced):', String(children)); + if (mermaidRef.current === container) { + const errorMessage = error instanceof Error ? error.message : String(error); + const errorPre = document.createElement('pre'); + errorPre.className = 'text-red-500 text-xs whitespace-pre-wrap break-words'; + errorPre.textContent = `Mermaid diagram setup error: ${errorMessage}`; + container.innerHTML = ''; + container.appendChild(errorPre); + } + } + }, 300); // Debounce delay + } + + // Cleanup function to clear the timer on unmount or before re-running effect + return () => { + if (debounceTimerRef.current) { + clearTimeout(debounceTimerRef.current); + } + }; + // Dependencies: renderAsDiagram ensures effect runs when diagram should be shown. + // Dependencies include all values used inside the effect to satisfy exhaustive-deps. + // The !hasRendered check prevents re-execution of render logic after success. + }, [renderAsDiagram, hasRendered, language, children, theme]); // Add children and theme back + + // For large JSON, skip syntax highlighting completely and use a simple pre tag + if (isLargeJsonBlock) { + return ( +
    +        {contentStr}
    +      
    + ); + } + + // Render based on language type + // If it's a mermaid language block and rendering as diagram is not requested (e.g., incomplete stream), display as plain text + if (language === 'mermaid' && !renderAsDiagram) { + return ( + + {contentStr} + + ); + } + + // If it's a mermaid language block and the message is complete, render as diagram + if (language === 'mermaid') { + // Container for Mermaid diagram + return
    ; + } + + + // ReactMarkdown determines inline vs block based on markdown syntax + // Inline code: `code` (no className with language) + // Block code: ```language (has className like "language-js") + // If there's no language className and no explicit inline prop, it's likely inline code + const isInline = inline ?? !className?.startsWith('language-'); + + // Generate dynamic inline code styles based on message role and theme + const getInlineCodeStyles = () => { + if (messageRole === 'user') { + // User messages have dark background (bg-primary), need light inline code + return theme === 'dark' + ? 'bg-primary-foreground/20 text-primary-foreground border border-primary-foreground/30' + : 'bg-primary-foreground/20 text-primary-foreground border border-primary-foreground/30'; + } else { + // Assistant messages have light background (bg-muted), need contrasting inline code + return theme === 'dark' + ? 'bg-muted-foreground/20 text-muted-foreground border border-muted-foreground/30' + : 'bg-slate-200 text-slate-800 border border-slate-300'; + } + }; + + // Handle non-Mermaid code blocks + return !isInline ? ( + + {contentStr} + + ) : ( + // Handle inline code with context-aware styling + + {children} + + ); +}); + +// Assign display name for React DevTools +CodeHighlight.displayName = 'CodeHighlight'; diff --git a/lightrag_webui/src/components/retrieval/QuerySettings.tsx b/lightrag_webui/src/components/retrieval/QuerySettings.tsx index 0b0096c0a8..7850c27ab1 100644 --- a/lightrag_webui/src/components/retrieval/QuerySettings.tsx +++ b/lightrag_webui/src/components/retrieval/QuerySettings.tsx @@ -213,15 +213,23 @@ export default function QuerySettings() { value={querySettings.chunk_top_k ?? ''} onChange={(e) => { const value = e.target.value - handleChange('chunk_top_k', value === '' ? '' : parseInt(value) || 0) + const numValue = value === '' ? '' : parseInt(value) || 0 + handleChange('chunk_top_k', numValue) + + // Auto-enable KG-only mode when chunk_top_k is 0 + if (numValue === 0) { + handleChange('only_kg_context', true) + } else { + handleChange('only_kg_context', false) + } }} onBlur={(e) => { const value = e.target.value if (value === '' || isNaN(parseInt(value))) { - handleChange('chunk_top_k', 20) + handleChange('chunk_top_k', 0) } }} - min={1} + min={0} placeholder={t('retrievePanel.querySettings.chunkTopKPlaceholder')} className="h-9 flex-1 pr-2 [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none [-moz-appearance:textfield]" /> diff --git a/lightrag_webui/src/components/retrieval/QuerySettings_original.tsx b/lightrag_webui/src/components/retrieval/QuerySettings_original.tsx new file mode 100644 index 0000000000..0b0096c0a8 --- /dev/null +++ b/lightrag_webui/src/components/retrieval/QuerySettings_original.tsx @@ -0,0 +1,457 @@ +import { useCallback, useMemo } from 'react' +import { QueryMode, QueryRequest } from '@/api/lightrag' +// Removed unused import for Text component +import Checkbox from '@/components/ui/Checkbox' +import Input from '@/components/ui/Input' +import UserPromptInputWithHistory from '@/components/ui/UserPromptInputWithHistory' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/Card' +import { + Select, + SelectContent, + SelectGroup, + SelectItem, + SelectTrigger, + SelectValue +} from '@/components/ui/Select' +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip' +import { useSettingsStore } from '@/stores/settings' +import { useTranslation } from 'react-i18next' +import { RotateCcw } from 'lucide-react' + +export default function QuerySettings() { + const { t } = useTranslation() + const querySettings = useSettingsStore((state) => state.querySettings) + const userPromptHistory = useSettingsStore((state) => state.userPromptHistory) + + const handleChange = useCallback((key: keyof QueryRequest, value: any) => { + useSettingsStore.getState().updateQuerySettings({ [key]: value }) + }, []) + + const handleSelectFromHistory = useCallback((prompt: string) => { + handleChange('user_prompt', prompt) + }, [handleChange]) + + const handleDeleteFromHistory = useCallback((index: number) => { + const newHistory = [...userPromptHistory] + newHistory.splice(index, 1) + useSettingsStore.getState().setUserPromptHistory(newHistory) + }, [userPromptHistory]) + + // Default values for reset functionality + const defaultValues = useMemo(() => ({ + mode: 'mix' as QueryMode, + top_k: 40, + chunk_top_k: 20, + max_entity_tokens: 6000, + max_relation_tokens: 8000, + max_total_tokens: 30000 + }), []) + + const handleReset = useCallback((key: keyof typeof defaultValues) => { + handleChange(key, defaultValues[key]) + }, [handleChange, defaultValues]) + + // Reset button component + const ResetButton = ({ onClick, title }: { onClick: () => void; title: string }) => ( + + + + + + +

    {title}

    +
    +
    +
    + ) + + return ( + + + {t('retrievePanel.querySettings.parametersTitle')} + {t('retrievePanel.querySettings.parametersDescription')} + + +
    +
    + {/* User Prompt - Moved to top for better dropdown space */} + <> + + + + + + +

    {t('retrievePanel.querySettings.userPromptTooltip')}

    +
    +
    +
    +
    + handleChange('user_prompt', value)} + onSelectFromHistory={handleSelectFromHistory} + onDeleteFromHistory={handleDeleteFromHistory} + history={userPromptHistory} + placeholder={t('retrievePanel.querySettings.userPromptPlaceholder')} + className="h-9" + /> +
    + + + {/* Query Mode */} + <> + + + + + + +

    {t('retrievePanel.querySettings.queryModeTooltip')}

    +
    +
    +
    +
    + + handleReset('mode')} + title="Reset to default (Mix)" + /> +
    + + + {/* Top K */} + <> + + + + + + +

    {t('retrievePanel.querySettings.topKTooltip')}

    +
    +
    +
    +
    + { + const value = e.target.value + handleChange('top_k', value === '' ? '' : parseInt(value) || 0) + }} + onBlur={(e) => { + const value = e.target.value + if (value === '' || isNaN(parseInt(value))) { + handleChange('top_k', 40) + } + }} + min={1} + placeholder={t('retrievePanel.querySettings.topKPlaceholder')} + className="h-9 flex-1 pr-2 [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none [-moz-appearance:textfield]" + /> + handleReset('top_k')} + title="Reset to default" + /> +
    + + + {/* Chunk Top K */} + <> + + + + + + +

    {t('retrievePanel.querySettings.chunkTopKTooltip')}

    +
    +
    +
    +
    + { + const value = e.target.value + handleChange('chunk_top_k', value === '' ? '' : parseInt(value) || 0) + }} + onBlur={(e) => { + const value = e.target.value + if (value === '' || isNaN(parseInt(value))) { + handleChange('chunk_top_k', 20) + } + }} + min={1} + placeholder={t('retrievePanel.querySettings.chunkTopKPlaceholder')} + className="h-9 flex-1 pr-2 [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none [-moz-appearance:textfield]" + /> + handleReset('chunk_top_k')} + title="Reset to default" + /> +
    + + + {/* Max Entity Tokens */} + <> + + + + + + +

    {t('retrievePanel.querySettings.maxEntityTokensTooltip')}

    +
    +
    +
    +
    + { + const value = e.target.value + handleChange('max_entity_tokens', value === '' ? '' : parseInt(value) || 0) + }} + onBlur={(e) => { + const value = e.target.value + if (value === '' || isNaN(parseInt(value))) { + handleChange('max_entity_tokens', 6000) + } + }} + min={1} + placeholder={t('retrievePanel.querySettings.maxEntityTokensPlaceholder')} + className="h-9 flex-1 pr-2 [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none [-moz-appearance:textfield]" + /> + handleReset('max_entity_tokens')} + title="Reset to default" + /> +
    + + + {/* Max Relation Tokens */} + <> + + + + + + +

    {t('retrievePanel.querySettings.maxRelationTokensTooltip')}

    +
    +
    +
    +
    + { + const value = e.target.value + handleChange('max_relation_tokens', value === '' ? '' : parseInt(value) || 0) + }} + onBlur={(e) => { + const value = e.target.value + if (value === '' || isNaN(parseInt(value))) { + handleChange('max_relation_tokens', 8000) + } + }} + min={1} + placeholder={t('retrievePanel.querySettings.maxRelationTokensPlaceholder')} + className="h-9 flex-1 pr-2 [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none [-moz-appearance:textfield]" + /> + handleReset('max_relation_tokens')} + title="Reset to default" + /> +
    + + + {/* Max Total Tokens */} + <> + + + + + + +

    {t('retrievePanel.querySettings.maxTotalTokensTooltip')}

    +
    +
    +
    +
    + { + const value = e.target.value + handleChange('max_total_tokens', value === '' ? '' : parseInt(value) || 0) + }} + onBlur={(e) => { + const value = e.target.value + if (value === '' || isNaN(parseInt(value))) { + handleChange('max_total_tokens', 30000) + } + }} + min={1} + placeholder={t('retrievePanel.querySettings.maxTotalTokensPlaceholder')} + className="h-9 flex-1 pr-2 [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none [-moz-appearance:textfield]" + /> + handleReset('max_total_tokens')} + title="Reset to default" + /> +
    + + + {/* Toggle Options */} + <> +
    + + + + + + +

    {t('retrievePanel.querySettings.enableRerankTooltip')}

    +
    +
    +
    + handleChange('enable_rerank', checked)} + /> +
    + +
    + + + + + + +

    {t('retrievePanel.querySettings.onlyNeedContextTooltip')}

    +
    +
    +
    + { + handleChange('only_need_context', checked) + if (checked) { + handleChange('only_need_prompt', false) + } + }} + /> +
    + +
    + + + + + + +

    {t('retrievePanel.querySettings.onlyNeedPromptTooltip')}

    +
    +
    +
    + { + handleChange('only_need_prompt', checked) + if (checked) { + handleChange('only_need_context', false) + } + }} + /> +
    + +
    + + + + + + +

    {t('retrievePanel.querySettings.streamResponseTooltip')}

    +
    +
    +
    + handleChange('stream', checked)} + /> +
    + + +
    +
    +
    +
    + ) +} diff --git a/lightrag_webui/src/contexts/TabVisibilityProvider.tsx b/lightrag_webui/src/contexts/TabVisibilityProvider.tsx index e64cf14602..38e9a6b02e 100644 --- a/lightrag_webui/src/contexts/TabVisibilityProvider.tsx +++ b/lightrag_webui/src/contexts/TabVisibilityProvider.tsx @@ -18,6 +18,7 @@ export const TabVisibilityProvider: React.FC = ({ ch // Initialize visibility state with all tabs visible const [visibleTabs, setVisibleTabs] = useState>(() => ({ 'documents': true, + 'sanitize-data': true, 'knowledge-graph': true, 'retrieval': true, 'api': true @@ -28,6 +29,7 @@ export const TabVisibilityProvider: React.FC = ({ ch setVisibleTabs((prev) => ({ ...prev, 'documents': true, + 'sanitize-data': true, 'knowledge-graph': true, 'retrieval': true, 'api': true diff --git a/lightrag_webui/src/contexts/TabVisibilityProvider_original.tsx b/lightrag_webui/src/contexts/TabVisibilityProvider_original.tsx new file mode 100644 index 0000000000..e64cf14602 --- /dev/null +++ b/lightrag_webui/src/contexts/TabVisibilityProvider_original.tsx @@ -0,0 +1,59 @@ +import React, { useState, useEffect, useMemo } from 'react'; +import { TabVisibilityContext } from './context'; +import { TabVisibilityContextType } from './types'; +import { useSettingsStore } from '@/stores/settings'; + +interface TabVisibilityProviderProps { + children: React.ReactNode; +} + +/** + * Provider component for the TabVisibility context + * Manages the visibility state of tabs throughout the application + */ +export const TabVisibilityProvider: React.FC = ({ children }) => { + // Get current tab from settings store + const currentTab = useSettingsStore.use.currentTab(); + + // Initialize visibility state with all tabs visible + const [visibleTabs, setVisibleTabs] = useState>(() => ({ + 'documents': true, + 'knowledge-graph': true, + 'retrieval': true, + 'api': true + })); + + // Keep all tabs visible because we use CSS to control TAB visibility instead of React + useEffect(() => { + setVisibleTabs((prev) => ({ + ...prev, + 'documents': true, + 'knowledge-graph': true, + 'retrieval': true, + 'api': true + })); + }, [currentTab]); + + // Create the context value with memoization to prevent unnecessary re-renders + const contextValue = useMemo( + () => ({ + visibleTabs, + setTabVisibility: (tabId: string, isVisible: boolean) => { + setVisibleTabs((prev) => ({ + ...prev, + [tabId]: isVisible, + })); + }, + isTabVisible: (tabId: string) => !!visibleTabs[tabId], + }), + [visibleTabs] + ); + + return ( + + {children} + + ); +}; + +export default TabVisibilityProvider; diff --git a/lightrag_webui/src/features/SanitizeData.tsx b/lightrag_webui/src/features/SanitizeData.tsx new file mode 100644 index 0000000000..f2d0260f75 --- /dev/null +++ b/lightrag_webui/src/features/SanitizeData.tsx @@ -0,0 +1,2740 @@ +// lightrag_webui/src/features/SanitizeData.tsx +import React, { useState, useEffect, useRef } from 'react'; +import axios from 'axios'; + +// const API_BASE = 'http://localhost:9621'; +const API_BASE = `${window.location.protocol}//${window.location.hostname}:9621`; + +export default function SanitizeData() { + const [entities, setEntities] = useState([]); + const [filterText, setFilterText] = useState(''); + const [currentPage, setCurrentPage] = useState(1); + const [selectedEntities, setSelectedEntities] = useState([]); + const [firstEntity, setFirstEntity] = useState(null); + const [targetEntity, setTargetEntity] = useState(''); + const [entityType, setEntityType] = useState(''); + + // Select Target Modal state + const [selectTargetModalOpen, setSelectTargetModalOpen] = useState(false); + const [selectedAction, setSelectedAction] = useState<'merge' | 'createRel' | null>(null); // Tracks which button triggered the modal + const [tempTarget, setTempTarget] = useState(''); // Temporary target selected in the modal + + // For controls that exist but are not being used. + // const [descriptionStrategy, setDescriptionStrategy] = useState('join_unique'); + // const [sourceIdStrategy, setSourceIdStrategy] = useState('join_unique'); + + // Dropdown suggestions = currently selected entities + const targetOptions = [...selectedEntities].sort((a, b) => a.localeCompare(b)); + + // Store fetched details: entityName → { desc, type, sourceId, filePath, relatedEntities, relationships } + const [entityDetails, setEntityDetails] = useState>({}); + + // Loading state (optional but nice UX) + const [loadingDetails, setLoadingDetails] = useState([]); + + const listContainerRef = useRef(null); + const [rowsPerPage, setRowsPerPage] = useState(20); // initial guess + + // Modal state for editing relationships + const [editRelationshipsModalOpen, setEditRelationshipsModalOpen] = useState(false); + const [editingEntityForRel, setEditingEntityForRel] = useState(null); + + // Temporary edits for relationships while modal is open + const [relationshipEdits, setRelationshipEdits] = useState>({}); + + // Unique entity types from selected entities + const [uniqueEntityTypes, setUniqueEntityTypes] = useState([]); + + // Modal state for selecting entity type + const [selectTypeModalOpen, setSelectTypeModalOpen] = useState(false); + + // State for the "Select Type" Modal + const [allEntityTypes, setAllEntityTypes] = useState([]); + const [selectedModalType, setSelectedModalType] = useState(''); + const [loadingTypes, setLoadingTypes] = useState(false); + const [modalFilterText, setModalFilterText] = useState(''); + const [typeSelectionContext, setTypeSelectionContext] = useState<'main' | 'create' | 'edit'>('main'); + const modalInputRef = useRef(null); + const typeItemRefs = useRef([]); + + const [typesLoading, setTypesLoading] = useState(true); + const [filterMode, setFilterMode] = useState<'none' | 'selected' | 'type' | 'orphan'>('none'); + const [typeFilteredEntities, setTypeFilteredEntities] = useState([]); + const [entityTypeMap, setEntityTypeMap] = useState>({}); + + const [orphanFilteredEntities, setOrphanFilteredEntities] = useState([]); + const [entityOrphanMap, setEntityOrphanMap] = useState>({}); + + // Create Entity Modal state + const [createEntityModalOpen, setCreateEntityModalOpen] = useState(false); + const [createEntityName, setCreateEntityName] = useState(''); + const [createEntityDescription, setCreateEntityDescription] = useState(''); + const [createEntityType, setCreateEntityType] = useState(''); + const [createEntitySourceId, setCreateEntitySourceId] = useState(''); + const [createError, setCreateError] = useState(null); // For error messages + + // Edit Entity Modal state (replaces old description-only) + const [editEntityModalOpen, setEditEntityModalOpen] = useState(false); + const [editEntityOriginalName, setEditEntityOriginalName] = useState(null); // For rename detection + const [editEntityName, setEditEntityName] = useState(''); + const [editEntityDescription, setEditEntityDescription] = useState(''); + const [editEntityType, setEditEntityType] = useState(''); + const [editEntitySourceId, setEditEntitySourceId] = useState(''); + const [editError, setEditError] = useState(null); + + // Create Relationship Modal state + const [createRelModalOpen, setCreateRelModalOpen] = useState(false); + const [createRelDescription, setCreateRelDescription] = useState(''); + const [createRelKeywords, setCreateRelKeywords] = useState(''); + const [createRelWeight, setCreateRelWeight] = useState(1.0); + const [createRelError, setCreateRelError] = useState(null); + + const createNameRef = useRef(null); + const createSourceRef = useRef(null); + const editSourceRef = useRef(null); + + const editNameRef = useRef(null); + + const filteredModalTypes = allEntityTypes.filter((type) => + type.toLowerCase().includes(modalFilterText.toLowerCase()) + ); + + const filterInputRef = useRef(null); + const previousFilterModeRef = useRef<'none' | 'selected' | 'type' | 'orphan'>('none'); + + // Batch Relationships Modal state + const [batchModalOpen, setBatchModalOpen] = useState(false); + const [batchSource, setBatchSource] = useState(null); + const [batchTargets, setBatchTargets] = useState([]); + const [batchTemplate, setBatchTemplate] = useState(null); // Will hold selected template object + const [batchTemplates, setBatchTemplates] = useState([]); // All templates from entity + const [batchPreview, setBatchPreview] = useState([]); // Generated previews + const [batchErrors, setBatchErrors] = useState([]); // e.g., duplicates + const [loadingBatch, setLoadingBatch] = useState(false); + + const [savingEntity, setSavingEntity] = useState(false); + const [savingRelationships, setSavingRelationships] = useState(false); + + + // Ref that always holds the current filterMode (fixes stale hotkey closures) + const filterModeRef = useRef(filterMode); + useEffect(() => { + filterModeRef.current = filterMode; + }, [filterMode]); + + + + const openVideoTutorial = () => { + window.open( + 'https://youtu.be/70iZxleULYY?si=PTq8S6fYIQpnEX8d', + '_blank', + 'noopener,noreferrer' + ); + }; + + // Build entityTypeMap and entityOrphanMap with single fetch per entity + const fetchEntityDetails = async (entityList: string[]) => { + try { + const typeMap: Record = {}; + const orphanMap: Record = {}; + await Promise.all( + entityList.map(async (name: string) => { + try { + const detailRes = await axios.get( + `${API_BASE}/graphs?label=${encodeURIComponent(name)}&max_depth=1&max_nodes=2` + ); + // Find main node by id (robust to order) + const mainNode = detailRes.data.nodes?.find((node: any) => node.id === name); + const type = mainNode?.properties?.entity_type || ''; + typeMap[name] = type; + + // Detect orphan from the same response + const isOrphan = (detailRes.data.nodes?.length || 0) <= 1 && (detailRes.data.edges?.length || 0) === 0; + orphanMap[name] = isOrphan; + } catch (err) { + console.error(`Error fetching details for ${name}:`, err); + } + }) + ); + setEntityTypeMap(typeMap); + setEntityOrphanMap(orphanMap); + // console.log('Types loaded:', Object.keys(typeMap).length); // Debug + // console.log('Orphans loaded:', Object.values(orphanMap).filter(Boolean).length); // Debug + } catch (err) { + console.error('Failed to fetch entity details:', err); + } finally { + setTypesLoading(false); + } + }; + + // Helper: force-refresh every entity that is currently selected + // This keeps all detail panels + any open modals in sync + const refreshAllSelectedDetails = async () => { + if (selectedEntities.length === 0) return; + + await Promise.all( + selectedEntities.map((name) => fetchEntityDetail(name, true)) + ); + }; + + // For loading the Select Type modal window + const fetchAllTypes = async () => { + setLoadingTypes(true); + try { + // 1. Get all entity names + const listRes = await axios.get(`${API_BASE}/graph/label/list`); + const entityNames = listRes.data as string[]; + + // 2. Fetch types for each name (in parallel) + // Note: If you have thousands of entities, we may need to chunk this later. + const typeSet = new Set(); + + await Promise.all( + entityNames.map(async (name) => { + try { + const detailRes = await axios.get( + `${API_BASE}/graphs?label=${encodeURIComponent(name)}&max_depth=1&max_nodes=1` + ); + const type = detailRes.data.nodes?.[0]?.properties?.entity_type; + if (type) typeSet.add(type); + } catch (err) { + console.error(`Error fetching type for ${name}:`, err); + } + }) + ); + + setAllEntityTypes(Array.from(typeSet).sort()); + } catch (err) { + console.error('Failed to fetch full type list:', err); + } finally { + setLoadingTypes(false); + } + }; + + useEffect(() => { + if (selectTypeModalOpen) { + fetchAllTypes(); + } + }, [selectTypeModalOpen]); + + // Fetch entities + useEffect(() => { + const fetchEntities = async () => { + try { + const response = await axios.get(`${API_BASE}/graph/label/list`); + const sorted = (response.data as string[]).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + setEntities(sorted); + fetchEntityDetails(sorted); // ← Pass sorted here + } catch (err) { + console.error('Failed to load entities:', err); + } + }; + fetchEntities(); + }, []); + + // Calculate visible rows dynamically + useEffect(() => { + const updateRowsPerPage = () => { + if (!listContainerRef.current) return; + + const container = listContainerRef.current; + const containerHeight = container.clientHeight; + const firstRow = container.querySelector('div.grid'); + const rowHeight = firstRow ? firstRow.getBoundingClientRect().height : 36; // fallback + + const headerHeight = 42; + const availableHeight = containerHeight - headerHeight; + const calculated = Math.max(5, Math.floor(availableHeight / rowHeight)); + + setRowsPerPage(calculated); + }; + + updateRowsPerPage(); + window.addEventListener('resize', updateRowsPerPage); + return () => window.removeEventListener('resize', updateRowsPerPage); + }, [entities, filterText]); + + const filteredEntities = entities.filter((e) => + e.toLowerCase().includes(filterText.toLowerCase()) + ); + + const totalPages = Math.max(1, Math.ceil(filteredEntities.length / rowsPerPage)); + const startIndex = (currentPage - 1) * rowsPerPage; + const paginatedEntities = filteredEntities.slice(startIndex, startIndex + rowsPerPage); + + // Fetch details for all selected entities when "Show Desc" is turned on + useEffect(() => { + if (selectedEntities.length > 0) { + selectedEntities.forEach((entityName) => { + fetchEntityDetail(entityName); + }); + } + }, [selectedEntities]); + + // Reset page when filter changes + useEffect(() => { + setCurrentPage(1); + }, [filterText]); + + // Listen for Esc key to cancel any open modal + useEffect(() => { + const handleEscKey = (e: KeyboardEvent) => { + if (e.key === 'Escape') { + if (editRelationshipsModalOpen) { + setEditRelationshipsModalOpen(false); + } else if (selectTypeModalOpen) { + setSelectTypeModalOpen(false); + } else if (createEntityModalOpen) { + setCreateEntityModalOpen(false); + } else if (editEntityModalOpen) { + setEditEntityModalOpen(false); + } else if (createRelModalOpen) { + setCreateRelModalOpen(false); + } else if (selectTargetModalOpen) { + setSelectTargetModalOpen(false); + setTempTarget(''); + setSelectedAction(null); + } else if (batchModalOpen) { + setBatchModalOpen(false); + } + } + }; + document.addEventListener('keydown', handleEscKey); + return () => document.removeEventListener('keydown', handleEscKey); + }, [createRelModalOpen, createEntityModalOpen, editEntityModalOpen, editRelationshipsModalOpen, selectTypeModalOpen, selectTargetModalOpen, batchModalOpen]); + + // Update unique entity types from selected entities' details + useEffect(() => { + const types = new Set(); + + selectedEntities.forEach((name) => { + const type = entityDetails[name]?.type; + if (type) { + types.add(type); + } + }); + setUniqueEntityTypes(Array.from(types).sort()); + }, [selectedEntities, entityDetails]); + + useEffect(() => { + if (selectTypeModalOpen) { + setModalFilterText(''); // Reset the search box + fetchAllTypes(); + } + }, [selectTypeModalOpen]); + + useEffect(() => { + typeItemRefs.current = []; + }, [filteredModalTypes]); + + useEffect(() => { + if (filteredModalTypes.length > 0 && !selectedModalType) { + setSelectedModalType(filteredModalTypes[0]); + } + }, [filteredModalTypes, selectedModalType]); + + useEffect(() => { + if (createEntityModalOpen) { + createNameRef.current?.focus(); + } + }, [createEntityModalOpen]); + + // Ctrl+K → Focus filter (and Show All if we are in any filtered mode) + useEffect(() => { + const handleCtrlK = (e: KeyboardEvent) => { + if ((e.ctrlKey || e.metaKey) && e.key.toLowerCase() === 'k') { + e.preventDefault(); + + // Use the ref so we always see the latest filterMode + if (filterModeRef.current !== 'none') { + handleShowAllAndFocus(); + } else { + focusFilterInput(); + } + } + }; + + document.addEventListener('keydown', handleCtrlK); + return () => document.removeEventListener('keydown', handleCtrlK); + }, []); // ← empty deps now (the ref gives us the live value) + + // Alt + K → Full Reset All (clear selections + everything) + focus filter + // Works the same in EVERY mode (including normal mode) + useEffect(() => { + const handleAltK = (e: KeyboardEvent) => { + if (e.altKey && e.key.toLowerCase() === 'k') { + e.preventDefault(); + handleFullResetAndFocus(); + } + }; + + document.addEventListener('keydown', handleAltK); + return () => document.removeEventListener('keydown', handleAltK); + }, []); // no dependencies needed + + // Ctrl + Enter → "Show Sel. Only" (works from normal, All Of Type, Orphans) + useEffect(() => { + const handleCtrlEnter = (e: KeyboardEvent) => { + if ((e.ctrlKey || e.metaKey) && e.key === 'Enter') { + if (selectedEntities.length > 0 && filterMode !== 'selected') { + e.preventDefault(); + setTargetEntity(''); // ← ADD THIS LINE + handleShowSelectedOnly(); + focusFilterInput(); + } + } + }; + + document.addEventListener('keydown', handleCtrlEnter); + return () => document.removeEventListener('keydown', handleCtrlEnter); + }, [selectedEntities.length, filterMode]); + + // Auto-focus the filter box when the app loads or page is refreshed (F5) + // This runs once after the component mounts + useEffect(() => { + // Tiny delay so the DOM has finished rendering the input + const timer = setTimeout(() => { + filterInputRef.current?.focus(); + // Optional: also select any existing text (nice when you reload) + filterInputRef.current?.select(); + }, 50); + + return () => clearTimeout(timer); + }, []); // ← empty array = run only once on mount + + // Ctrl + ; → open Select Type modal + ALWAYS clear Target Entity on main screen + useEffect(() => { + const handleCtrlSemicolon = (e: KeyboardEvent) => { + if ((e.ctrlKey || e.metaKey) && e.key === ';') { + e.preventDefault(); + + // Ignore inside relationship modals + if (editRelationshipsModalOpen || createRelModalOpen) return; + + // Force-clear Target Entity (setTimeout(0) guarantees the state update happens) + setTimeout(() => { + setTargetEntity(''); + }, 50); + + // Choose correct context + if (createEntityModalOpen) { + setTypeSelectionContext('create'); + } else if (editEntityModalOpen) { + setTypeSelectionContext('edit'); + } else { + setTypeSelectionContext('main'); + } + + setSelectTypeModalOpen(true); + } + }; + + document.addEventListener('keydown', handleCtrlSemicolon); + return () => document.removeEventListener('keydown', handleCtrlSemicolon); + }, [ + createEntityModalOpen, + editEntityModalOpen, + editRelationshipsModalOpen, + createRelModalOpen + ]); + + useEffect(() => { + if (batchModalOpen) { + const fetchTemplates = async () => { + setLoadingBatch(true); + try { + const res = await axios.get(`${API_BASE}/graphs?label=${encodeURIComponent('Relationship Templates')}&max_depth=1&max_nodes=1`); + const mainNode = res.data.nodes?.find((node: any) => node.id === 'Relationship Templates'); + const jsonDesc = mainNode?.properties?.description || '{}'; + const parsed = JSON.parse(jsonDesc); + setBatchTemplates(parsed.relationship_templates || []); + } catch (err) { + console.error('Failed to fetch templates:', err); + setBatchErrors(['Failed to load templates. Check if "Relationship Templates" entity exists.']); + } finally { + setLoadingBatch(false); + } + }; + fetchTemplates(); + // Pre-populate with current selections if in 'selected' mode + if (filterMode === 'selected' && selectedEntities.length >= 2) { + setBatchTargets(selectedEntities); // Start with all selected as potential targets/source + } + } else { + // Reset on close + setBatchSource(null); + setBatchTargets([]); + setBatchTemplate(null); + setBatchPreview([]); + setBatchErrors([]); + } + }, [batchModalOpen, filterMode, selectedEntities]); + + // Pagination handlers + const goToFirst = () => setCurrentPage(1); + const goToPrev = () => setCurrentPage((p) => Math.max(1, p - 1)); + const goToNext = () => setCurrentPage((p) => Math.min(totalPages, p + 1)); + const goToLast = () => setCurrentPage(totalPages); + + const handlePageInputChange = (e: React.ChangeEvent) => { + const val = e.target.value; + if (val === '') return; + const num = parseInt(val, 10); + if (!isNaN(num)) { + setCurrentPage(Math.max(1, Math.min(totalPages, num))); + } + }; + + const handleShowSelectedOnly = () => { + if (selectedEntities.length === 0) { + alert('Please select at least one entity first (check the boxes on the left).'); + return; + } + + setFilterMode('selected'); + setCurrentPage(1); + setFilterText(''); + }; + + const handleShowAllOfType = () => { + if (typesLoading) { + alert('Entity types are still loading. Please wait a moment and try again.'); + return; + } + + // If we're already in type mode → do nothing (button is disabled) + if (filterMode === 'type') return; + + // Otherwise open the modal so user can pick a type + setTypeSelectionContext('main'); + setSelectTypeModalOpen(true); + }; + + // Apply a type filter immediately when the user picks a type from the modal + const applyTypeFilter = (chosenType: string) => { + if (!chosenType) return; + + const entitiesOfType = entities + .filter((name) => entityTypeMap[name] === chosenType) + .sort((a, b) => a.toLowerCase().localeCompare(b.toLowerCase())); + + setEntityType(chosenType); + setTypeFilteredEntities(entitiesOfType); + setFilterMode('type'); + setCurrentPage(1); + setFilterText(''); + }; + + const handleShowOrphans = () => { + if (typesLoading) { // Reuses the same loading state (since orphans load with types) + alert('Entity details are still loading. Please wait a moment and try again.'); + return; + } + + // console.log('Showing orphans'); // Debug + const orphans = entities.filter((name) => entityOrphanMap[name] === true).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + console.log('Orphan count:', orphans.length); // Debug + + setOrphanFilteredEntities(orphans); + setFilterMode('orphan'); + setCurrentPage(1); + setFilterText(''); + }; + + const handleClearSelected = () => { + setSelectedEntities([]); // uncheck all checkboxes + setFirstEntity(null); // deselect "Keep First" radio + }; + + const handleCreateEntity = async () => { + if (!createEntityName.trim()) { + setCreateError('Entity name is required.'); + return; + } + + setCreateError(null); + + try { + const entityData: Record = { + description: createEntityDescription, + entity_type: createEntityType, + }; + if (createEntitySourceId.trim()) { + entityData.source_id = createEntitySourceId; + } + + const response = await axios.post(`${API_BASE}/graph/entity/create`, { + entity_name: createEntityName, + entity_data: entityData, + }); + + if (response.status === 200) { + // Success: Close modal, refresh entities + setCreateEntityModalOpen(false); + // Refresh entity list (re-fetch) + const listRes = await axios.get(`${API_BASE}/graph/label/list`); + const sorted = (listRes.data as string[]).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + setEntities(sorted); + // Also refresh types/orphans map if needed (call fetchEntityTypes again) + // For simplicity, reload page or add a full refresh here if necessary + + fetchSingleEntityDetails(createEntityName); + + alert(response.data.message); // Show success message + } + } catch (err: any) { + console.error('Failed to create entity:', err); + let errorMsg = 'Failed to create entity. Check console for details.'; + if (err.response?.data?.detail) { + errorMsg = err.response.data.detail; // e.g., "Entity 'Walsh' already exists" + } else if (err.response?.data?.message) { + errorMsg = err.response.data.message; // Fallback if API uses "message" + } else if (err.message) { + errorMsg = err.message; // Broader fallback (e.g., network errors) + } + setCreateError(errorMsg); // Show in modal + } + }; + + // Check if the Edit Entity modal has any actual changes + const hasEntityChanges = (): boolean => { + if (!editEntityOriginalName) return false; + + const original = entityDetails[editEntityOriginalName]; + + const nameChanged = editEntityName !== editEntityOriginalName; + const descChanged = editEntityDescription !== (original?.description || ''); + const typeChanged = editEntityType !== (original?.type || ''); + const sourceChanged = editEntitySourceId !== (original?.sourceId || ''); + + return nameChanged || descChanged || typeChanged || sourceChanged; + }; + + const handleRemoveSepDelimiters = async () => { + if (!confirm('Are you sure? This will remove all delimiters from entity and relationship descriptions across the entire index. This cannot be undone.')) { + return; + } + + alert('Starting cleanup... This may take a moment depending on the number of entities.'); + + try { + // Fetch all entity names + const listRes = await axios.get(`${API_BASE}/graph/label/list`); + const entityNames = listRes.data as string[]; + + // Process each entity in parallel (up to browser limits) + await Promise.all( + entityNames.map(async (name) => { + try { + // Fetch full details (entity + relationships) + const detailRes = await axios.get( + `${API_BASE}/graphs?label=${encodeURIComponent(name)}&max_depth=1&max_nodes=20000` + ); + const data = detailRes.data; + + // Find main entity node + const mainNode = data.nodes?.find((node: any) => node.id === name); + if (!mainNode) return; + + // Clean entity description + let cleanDesc = mainNode.properties?.description || ''; + cleanDesc = cleanDesc.replace(//g, ' '); // Replace with space + + // Update entity if description changed + if (cleanDesc !== mainNode.properties?.description) { + const updatePayload = { + entity_name: name, + updated_data: { description: cleanDesc }, + allow_rename: false, + allow_merge: false, + }; + await axios.post(`${API_BASE}/graph/entity/edit`, updatePayload); + } + + // Clean relationships + await Promise.all( + (data.edges || []).map(async (edge: any) => { + let cleanRelDesc = edge.properties?.description || ''; + cleanRelDesc = cleanRelDesc.replace(//g, ' '); + + if (cleanRelDesc !== edge.properties?.description) { + const relPayload = { + source_id: edge.source, + target_id: edge.target, + updated_data: { description: cleanRelDesc }, + }; + await axios.post(`${API_BASE}/graph/relation/edit`, relPayload); + } + }) + ); + } catch (err) { + console.error(`Error cleaning ${name}:`, err); + } + }) + ); + + // Refresh the entire app after cleanup + const refreshedList = await axios.get(`${API_BASE}/graph/label/list`); + const sorted = (refreshedList.data as string[]).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + setEntities(sorted); + fetchEntityDetails(sorted); + await refreshAllSelectedDetails(); + + alert('Cleanup complete! All delimiters removed from descriptions.'); + } catch (err) { + console.error('Cleanup failed:', err); + alert('Cleanup failed. Check console for details.'); + } + }; + + const handleSaveEntity = async () => { + if (!editEntityName.trim()) { + setEditError('Entity name is required.'); + return; + } + // Prevent duplicate name frontend (but backend will confirm) + if (editEntityName !== editEntityOriginalName && entities.includes(editEntityName)) { + setEditError(`Entity name "${editEntityName}" already exists.`); + return; + } + + // Only save if the user actually changed something + if (!hasEntityChanges()) { + setEditError('No changes detected.'); + return; + } + + const original = entityDetails[editEntityOriginalName!]; + + const nameChanged = editEntityName !== editEntityOriginalName; + const descChanged = editEntityDescription !== (original?.description || ''); + const typeChanged = editEntityType !== (original?.type || ''); + const sourceChanged = editEntitySourceId !== (original?.sourceId || ''); + + const hasAnyChange = nameChanged || descChanged || typeChanged || sourceChanged; + + if (!hasAnyChange) { + setEditError('No changes detected.'); + return; + } + + try { + + setSavingEntity(true); + + const updatedData: Record = { + description: editEntityDescription, + entity_type: editEntityType, + source_id: editEntitySourceId || '', + }; + + const allowRename = editEntityName !== editEntityOriginalName; + const allowMerge = false; // Explicitly false to prevent merge on conflict + + if (allowRename) { + updatedData.entity_name = editEntityName; // ← Put new name in updated_data + } + + const payload = { + entity_name: editEntityOriginalName, + updated_data: updatedData, + allow_rename: allowRename, + allow_merge: allowMerge, + }; + + // console.log('Sending edit payload:', JSON.stringify(payload, null, 2)); // Debug + + const response = await axios.post(`${API_BASE}/graph/entity/edit`, payload); + + // console.log('Edit response:', response.data); // Debug + + if (response.status === 200) { + setEditEntityModalOpen(false); + setEditEntityOriginalName(null); + + if (allowRename && editEntityOriginalName) { + // Update selectedEntities with new name + setSelectedEntities((prev) => + prev.map((n) => (n === editEntityOriginalName ? editEntityName : n)) + ); + + // Update firstEntity if it was the old name + if (firstEntity === editEntityOriginalName) { + setFirstEntity(editEntityName); + } + + // Migrate entityDetails to new key + setEntityDetails((prev) => { + if (prev[editEntityOriginalName]) { + const newDetails = { ...prev }; + newDetails[editEntityName] = { ...prev[editEntityOriginalName] }; + delete newDetails[editEntityOriginalName]; + return newDetails; + } + return prev; + }); + } + + // Full refresh of the entity list + type/orphan maps (already there – keep it) + const listRes = await axios.get(`${API_BASE}/graph/label/list`); + const sorted = (listRes.data as string[]).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + setEntities(sorted); + fetchEntityDetails(sorted); + + // ← NEW: Refresh EVERY selected entity so all panels (and any related-entity lists) are up-to-date + await refreshAllSelectedDetails(); + + // Conditional full page reload only if name changed AND at least one other field changed + if (allowRename && (descChanged || typeChanged || sourceChanged)) { + alert(`Reloading page to ensure all mixed name and property changes are fully reflected. + + To avoid a full refresh in the future: + Change the entity name first and save the change. + Then edit the entity properties`); + window.location.reload(); + } + + alert('Entity updated successfully!'); + } else { + setEditError('Update failed with status: ' + response.status); + } + } catch (err: any) { + console.error('Failed to edit entity:', err); + let errorMsg = 'Failed to update entity.'; + if (err.response?.data?.detail) { + errorMsg = err.response.data.detail; + } else if (err.response?.data?.message) { + errorMsg = err.response.data.message; + } else if (err.message) { + errorMsg = err.message; + } + setEditError(errorMsg); + } finally { + setSavingEntity(false); // ← ADD THIS: Stop saving indicator, even on error + } + }; + + const handleDeleteEntities = async () => { + if (selectedEntities.length === 0 || filterMode !== 'selected') return; // Safety check + + if (!confirm(`Are you sure you want to delete ${selectedEntities.length} entity/entities? This cannot be undone.`)) { + return; + } + + try { + let successCount = 0; + let errorMessages: string[] = []; + + for (const entityName of selectedEntities) { + try { + const payload = { entity_name: entityName }; + const response = await axios.delete(`${API_BASE}/documents/delete_entity`, { data: payload }); + + if (response.status === 200) { + successCount++; + } else { + errorMessages.push(`Failed to delete ${entityName} (status: ${response.status})`); + } + } catch (err: any) { + console.error(`Error deleting ${entityName}:`, err); + let msg = `Failed to delete ${entityName}.`; + if (err.response?.status === 404) { + msg = `${entityName} not found.`; + } else if (err.response?.data?.detail) { + msg = err.response.data.detail; + } + errorMessages.push(msg); + } + } + + // Full refresh after deletes + const listRes = await axios.get(`${API_BASE}/graph/label/list`); + const sorted = (listRes.data as string[]).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + setEntities(sorted); + fetchEntityDetails(sorted); + + // Clear selections and firstEntity + setSelectedEntities([]); + setFirstEntity(null); + + // Show summary + if (successCount === selectedEntities.length) { + alert('All selected entities deleted successfully!'); + } else if (successCount > 0) { + alert(`Deleted ${successCount} entity/entities successfully. Errors: ${errorMessages.join(', ')}`); + } else { + alert(`Failed to delete any entities. Errors: ${errorMessages.join(', ')}`); + } + } catch (err) { + console.error('Unexpected error during delete:', err); + alert('An unexpected error occurred during delete.'); + } + }; + + const handleCreateRelationship = async (relTarget: string = targetEntity) => { + + if (!createRelDescription.trim()) { + setCreateRelError('Relationship description is required.'); + return; + } + setCreateRelError(null); + // Derive source and target + if (selectedEntities.length !== 2 || !relTarget || !selectedEntities.includes(relTarget)) { + setCreateRelError('Invalid selection or target.'); + return; + } + const sourceEntity = selectedEntities.find((n) => n !== relTarget) || ''; + try { + const relationData: Record = { + description: createRelDescription, + keywords: createRelKeywords, + weight: createRelWeight, + }; + const payload = { + source_entity: sourceEntity, + target_entity: relTarget, + relation_data: relationData, + }; + const response = await axios.post(`${API_BASE}/graph/relation/create`, payload); + + if (response.status === 200) { + setCreateRelModalOpen(false); + // Refresh details for affected entities + fetchEntityDetail(sourceEntity, true); + fetchEntityDetail(relTarget, true); + alert('Relationship created successfully!'); + } + } catch (err: any) { + console.error('Failed to create relationship:', err); + let errorMsg = 'Failed to create relationship.'; + if (err.response?.status === 400) { + errorMsg = err.response?.data?.detail || 'Invalid request—check if entities exist or duplicate relationship.'; + } else if (err.response?.data?.detail) { + errorMsg = err.response.data.detail; + } else if (err.response?.data?.message) { + errorMsg = err.response.data.message; + } else if (err.message) { + errorMsg = err.message; + } + setCreateRelError(errorMsg); + } + }; + + const handleMergeEntities = async (mergeTarget: string = targetEntity) => { // ← changed: accept param, default to state + if (!mergeTarget || !selectedEntities.includes(mergeTarget)) { + alert('Please select a target entity from the dropdown first.'); + return; + } + const entitiesToChange = selectedEntities.filter((n) => n !== mergeTarget); // ← use mergeTarget + if (entitiesToChange.length === 0) { + alert('No source entities to merge (select at least one besides the target).'); + return; + } + if (!confirm(`Are you sure you want to merge ${entitiesToChange.length} entity/entities into "${mergeTarget}"? Sources will be deleted. This cannot be undone.`)) { // ← use mergeTarget + return; + } + try { + const payload = { + entities_to_change: entitiesToChange, + entity_to_change_into: mergeTarget, // ← use mergeTarget + }; + const response = await axios.post(`${API_BASE}/graph/entities/merge`, payload); + if (response.status === 200) { + // Full refresh after merge + const listRes = await axios.get(`${API_BASE}/graph/label/list`); + const sorted = (listRes.data as string[]).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()) + ); + setEntities(sorted); + fetchEntityDetails(sorted); + // Refresh target details (shows transferred relations) + fetchEntityDetail(mergeTarget, true); // ← use mergeTarget + // Clear selections and firstEntity + setSelectedEntities([]); + setFirstEntity(null); + setTargetEntity(''); // Optional: clear target + alert(response.data.message || 'Entities merged successfully!'); + } + } catch (err: any) { + console.error('Failed to merge entities:', err); + let errorMsg = 'Failed to merge entities.'; + if (err.response?.status === 400) { + errorMsg = err.response?.data?.detail || 'Invalid request—check if target exists or sources are valid.'; + } else if (err.response?.data?.detail) { + errorMsg = err.response.data.detail; + } else if (err.response?.data?.message) { + errorMsg = err.response.data.message; + } else if (err.message) { + errorMsg = err.message; + } + alert(errorMsg); + } + }; + + const handleGeneratePreview = async () => { + if (!batchSource || !batchTemplate || batchTargets.length <= 1) return; + setLoadingBatch(true); + setBatchPreview([]); + setBatchErrors([]); + const targets = batchTargets.filter((t) => t !== batchSource); + const previews = []; + const errors = []; + try { + // Fetch existing relationships for source + const res = await axios.get(`${API_BASE}/graphs?label=${encodeURIComponent(batchSource)}&max_depth=1&max_nodes=20000`); + const existingEdges = res.data.edges || []; + for (const target of targets) { + const desc = batchTemplate.description_template + .replace('{source}', batchSource) + .replace('{target}', target); + const keywords = batchTemplate.keywords_template + .replace('{source}', batchSource) + .replace('{target}', target); + const weight = batchTemplate.default_weight; + // Check for duplicate + const duplicate = existingEdges.some( + (edge: any) => edge.source === batchSource && edge.target === target + ); + previews.push({ + target, + description: desc, + keywords, + weight, + error: duplicate ? 'Duplicate relationship exists' : null, + }); + if (duplicate) errors.push(`Duplicate for ${target}`); + } + setBatchPreview(previews); + setBatchErrors(errors); + } catch (err) { + console.error('Preview failed:', err); + setBatchErrors(['Failed to generate preview.']); + } finally { + setLoadingBatch(false); + } + }; + + const handleBatchCreate = async () => { + if (batchPreview.some((p) => p.error) && !confirm('Some duplicates found. Proceed anyway?')) return; + setLoadingBatch(true); + try { + let successCount = 0; + for (const prev of batchPreview) { + if (prev.error) continue; // Skip duplicates (or handle as needed) + const payload = { + source_entity: batchSource, + target_entity: prev.target, + relation_data: { + description: prev.description, + keywords: prev.keywords, + weight: prev.weight, + }, + }; + await axios.post(`${API_BASE}/graph/relation/create`, payload); + successCount++; + } + // Refresh affected entities + await refreshAllSelectedDetails(); + fetchEntityDetail(batchSource!, true); + batchTargets.forEach((t) => fetchEntityDetail(t, true)); + alert(`Created ${successCount} relationships successfully!`); + setBatchModalOpen(false); + } catch (err) { + console.error('Batch create failed:', err); + setBatchErrors(['Failed to create some relationships.']); + } finally { + setLoadingBatch(false); + } + }; + + // Toggle an entity in/out of the selection + // This is used both by the checkbox AND by clicking the whole row + const toggleEntitySelection = (entityName: string) => { + if (selectedEntities.includes(entityName)) { + // Deselect + setSelectedEntities(selectedEntities.filter((e) => e !== entityName)); + if (firstEntity === entityName) { + setFirstEntity(null); + } + } else { + // Select + setSelectedEntities([...selectedEntities, entityName]); + } + }; + + // Focus the filter input and select any existing text + // Used after Clear Sel., Reset All, Show All, etc. + const focusFilterInput = () => { + setTimeout(() => { + filterInputRef.current?.focus(); + filterInputRef.current?.select(); + }, 50); + }; + + // Show All + reset everything + put cursor back in filter + const handleShowAllAndFocus = () => { + setFilterMode('none'); + setFilterText(''); + setCurrentPage(1); + setEntityType(''); + setTargetEntity(''); + focusFilterInput(); // reuse the helper you already have + }; + + // Full Reset All + focus filter + // Clears selections, filter text, mode, page, type, target – everything + const handleFullResetAndFocus = () => { + setFilterMode('none'); + setSelectedEntities([]); + setFirstEntity(null); + setFilterText(''); + setCurrentPage(1); + setEntityType(''); + setTargetEntity(''); + focusFilterInput(); // reuses your existing helper + }; + + const fetchSingleEntityDetails = async (name: string) => { + try { + const detailRes = await axios.get( + `${API_BASE}/graphs?label=${encodeURIComponent(name)}&max_depth=1&max_nodes=2` + ); + // Find main node by id (robust to order) + const mainNode = detailRes.data.nodes?.find((node: any) => node.id === name); + const type = mainNode?.properties?.entity_type || ''; + + const isOrphan = (detailRes.data.nodes?.length || 0) <= 1 && (detailRes.data.edges?.length || 0) === 0; + + setEntityTypeMap(prev => ({ ...prev, [name]: type })); + setEntityOrphanMap(prev => ({ ...prev, [name]: isOrphan })); + } catch (err) { + console.error(`Error fetching single entity details for ${name}:`, err); + } + }; + + const fetchEntityDetail = async (entityName: string, force = false) => { + + // console.log(`fetchEntityDetail called for: "${entityName}"`); + + // Skip if we already have it + //if (entityDetails[entityName]) return; + if (entityDetails[entityName] && !force) { + // console.log(`Already have details for "${entityName}" - skipping`); + return; + } + + // console.log(`Fetching details for "${entityName}"...`); + + setLoadingDetails((prev) => [...prev, entityName]); + + try { + console.log("Making axios request..."); + const encodedName = encodeURIComponent(entityName); + const url = `${API_BASE}/graphs?label=${encodedName}&max_depth=1&max_nodes=20000`; + console.log("Request URL:", url); + + const response = await axios.get(url); + console.log("Response received:", response.status, response.data); + + const data = response.data; + + // Parse the response (based on your Python code structure) + let mainDesc = "No description found."; + let mainType = ""; + let mainSourceId = ""; + let mainFilePath = ""; + + const related: any[] = []; + const edges: any[] = []; + + // Process nodes + (data.nodes || []).forEach((node: any) => { + const props = node.properties || {}; + if (node.id === entityName) { + mainDesc = props.description || mainDesc; + mainType = props.entity_type || mainType; + mainSourceId = props.source_id || ""; + mainFilePath = props.file_path || ""; + } else { + related.push({ + name: node.id, + type: props.entity_type || "", + description: props.description || "No description", + }); + } + }); + + // Process edges/relationships + (data.edges || []).forEach((edge: any) => { + const props = edge.properties || {}; + const from = edge.source; + const to = edge.target; + + // Keep the relationship only if Bypass Switch (or whichever entity we queried) is one of the endpoints + if (from === entityName || to === entityName) { + edges.push({ + from, + to, + relation: props.description || "", + weight: props.weight || 1.0, + keywords: props.keywords || "", + }); + } + }); + + // Store the parsed data + setEntityDetails((prev) => ({ + ...prev, + [entityName]: { + type: mainType, + description: mainDesc, + sourceId: mainSourceId, + filePath: mainFilePath, + relatedEntities: related, + relationships: edges, + }, + })); + } catch (err) { + console.error(`Error fetching "${entityName}":`, err); + // Optional: store error state + } finally { + setLoadingDetails((prev) => prev.filter((n) => n !== entityName)); + } + }; + + const openEditEntityModal = (entityName: string) => { + setEditEntityOriginalName(entityName); + setEditEntityName(entityName); + setEditEntityDescription(entityDetails[entityName]?.description || ''); + setEditEntityType(entityDetails[entityName]?.type || ''); + setEditEntitySourceId(entityDetails[entityName]?.sourceId || ''); + setEditError(null); + setEditEntityModalOpen(true); + + // Focus the Entity Name field when modal opens + setTimeout(() => editNameRef.current?.focus(), 100); + }; + + const openEditRelationshipsModal = (entityName: string) => { + setEditingEntityForRel(entityName); + + // Initialize editable copies of relationships + const initialEdits: Record = {}; + entityDetails[entityName]?.relationships?.forEach((rel: any) => { + const key = `${rel.from}-${rel.to}`; // simple unique key + initialEdits[key] = { ...rel }; // shallow copy + }); + + setRelationshipEdits(initialEdits); + + setEditRelationshipsModalOpen(true); + }; + + const triggerGraphRefresh = async () => { + try { + const response = await axios.post(`${API_BASE}/graph/refresh-data`); + if (response.status === 200) { + // console.log("Graph data refresh triggered successfully"); + // Optional: show toast/alert later + } + } catch (err) { + console.error("Failed to trigger graph refresh:", err); + alert("Changes saved, but failed to refresh graph view. Please restart server or try again."); + } + }; + + const saveAllRelationshipChanges = async () => { + if (!editingEntityForRel) return; + + try { + setSavingRelationships(true); + let successCount = 0; + + for (const [key, editedRel] of Object.entries(relationshipEdits)) { + const originalRel = entityDetails[editingEntityForRel].relationships.find( + (r: any) => `${r.from}-${r.to}` === key + ); + + if (!originalRel) continue; + + // Only send if something changed + if ( + editedRel.relation !== originalRel.relation || + editedRel.weight !== originalRel.weight || + editedRel.keywords !== originalRel.keywords + ) { + const payload = { + source_id: editedRel.from, + target_id: editedRel.to, + updated_data: { + description: editedRel.relation, + keywords: editedRel.keywords, + weight: editedRel.weight, + }, + }; + + const res = await axios.post(`${API_BASE}/graph/relation/edit`, payload); + if (res.status === 200) successCount++; + } + } + + if (successCount > 0) { + // Refresh ALL selected entities instead of just the one + await refreshAllSelectedDetails(); + + alert(`Saved ${successCount} relationship change(s) successfully!`); + } else { + alert("No changes detected."); + } + setEditRelationshipsModalOpen(false); + } catch (err) { + console.error("Failed to save relationship changes:", err); + alert("Error saving relationships. Check console."); + } finally { + setSavingRelationships(false); // ← ADD THIS: Stop saving indicator, even on error + } + }; + + const deleteRelationship = async (from: string, to: string) => { + if (!confirm(`Are you sure you want to delete the relationship from ${from} to ${to}? This cannot be undone.`)) { + return; + } + + try { + // Change to DELETE method + correct parameter names + await axios.delete(`${API_BASE}/documents/delete_relation`, { + data: { // Use 'data' for body in DELETE (axios requires this for non-GET methods) + source_entity: from, + target_entity: to, + }, + }); + + // console.log(`Deleted relationship: ${from} → ${to}`); + + // Remove from temp edits + setRelationshipEdits((prev) => { + const newEdits = { ...prev }; + delete newEdits[`${from}-${to}`]; + return newEdits; + }); + + alert("Relationship deleted successfully!"); + } catch (err) { + console.error("Failed to delete relationship:", err); + alert("Error deleting relationship. Check console."); + } finally { + // After successful delete, refresh BOTH sides (and any other selected entities) + await refreshAllSelectedDetails(); + } + }; + + const displayEntities = filterMode === 'selected' + ? [...selectedEntities].sort((a, b) => a.toLowerCase().localeCompare(b.toLowerCase())) + : filterMode === 'type' + ? typeFilteredEntities + : filterMode === 'orphan' + ? orphanFilteredEntities + : paginatedEntities; + + const buttonTabIndex = filterText.length > 0 ? -1 : 0; + + // Auto-focus the first entity row when entering filtered modes + useEffect(() => { + // Only run when the mode has actually changed to a filtered one + if ( + previousFilterModeRef.current !== filterMode && + (filterMode === 'selected' || filterMode === 'type' || filterMode === 'orphan') + ) { + if (displayEntities.length > 0) { + setTimeout(() => { + const firstRow = listContainerRef.current?.querySelector('div[tabindex="0"]') as HTMLElement | null; + firstRow?.focus(); + }, 50); + } + // Update the ref for next time + previousFilterModeRef.current = filterMode; + } + }, [filterMode, displayEntities]); + + + + + return ( +
    + {/* Top row - minimum height to ensure controls are visible */} +
    + {/* Upper Left */} +
    + {filterMode !== 'none' && ( +
    { + if (filterMode === 'type') { + setTypeSelectionContext('main'); + setSelectTypeModalOpen(true); + } else { + handleShowAllAndFocus(); + } + }} + > + {filterMode === 'type' ? ( + <> + Showing only entities of type:{' '} + + {entityType} + {' '} + ({displayEntities.length}) + + ) : filterMode === 'orphan' ? ( + `Showing only orphan entities (${displayEntities.length})` + ) : ( + `Showing only selected entities (${displayEntities.length})` + )} +
    + )} + + {filterMode === 'none' && ( + <> + setFilterText(e.target.value)} + ref={filterInputRef} + /> +
    + + + + {/* Pagination (unchanged) */} +
    + + + +
    + Pg + + /{Math.ceil(filteredEntities.length / rowsPerPage) || 1} +
    + + + +
    +
    + + )} + +
    + {filterMode === 'none' ? ( + + ) : ( +
    // ← invisible placeholder + )} + + {filterMode === 'type' ? ( + + ) : filterMode === 'orphan' ? ( + + ) : ( + + )} + + + +
    +
    + + {/* Upper Right - should now always be visible */} +
    +
    + {/* Left-side action buttons */} +
    + + + + + + + + + +
    + + {/* New Clean Descriptions button */} + + + {/* Video Tutorial button – pushed all the way to the right */} + +
    +
    +
    + + {/* Bottom row */} +
    + {/* Lower Left */} +
    +
    +
    +
    Select
    Entities
    +
    Entity Name
    +
    + +
    + {displayEntities.map((entityName) => ( +
    { + // Ignore the click if the user has selected text (for copying) + if (window.getSelection()?.toString().trim()) { + return; + } + toggleEntitySelection(entityName); + }} + tabIndex={0} + onKeyDown={(e) => { + // Allow global hotkeys (Ctrl+Enter, etc.) to work + if (e.ctrlKey || e.metaKey) return; + + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + toggleEntitySelection(entityName); + } + }} + > +
    + { + // We still let the native checkbox work normally + toggleEntitySelection(entityName); + }} + onClick={(e) => e.stopPropagation()} /* ← prevents row onClick from firing when clicking the box directly */ + className="h-4 w-4 text-blue-600 rounded" + /> +
    +
    + {entityName} +
    +
    + ))} + + {displayEntities.length === 0 && ( +
    + {selectedEntities.length > 0 + ? "No selected entities to show" + : "No entities match current filter"} +
    + )} +
    +
    +
    + + {/* Lower Right – always shows details for current selection (no "Show Desc" needed) */} +
    +
    + {selectedEntities.length === 0 ? ( +
    + Select one or more entities on the left to view their details +
    + ) : ( +
    + {selectedEntities.map((name) => ( +
    +
    + {name} +
    + + + + +
    +
    + + {loadingDetails.includes(name) ? ( +
    Loading details...
    + ) : entityDetails[name] ? ( +
    + {/* Type */} +
    + Type: {entityDetails[name].type || "No type found."} +
    + + {/* Related Entities count */} +
    + Related Entities: {entityDetails[name].relatedEntities?.length || 0} +
    + + {/* Description */} +
    + Description: +
    + {entityDetails[name].description + ?.split('') + .map((part: string, i: number) => ( +

    + {part.trim() || "No description found."} +

    + )) || "No description found."} +
    +
    + + {/* Source ID */} +
    + Source ID: +
    + {entityDetails[name].sourceId + ?.split('') + .map((id: string, i: number) => ( +

    + {id.trim() || ""} +

    + )) || ""} +
    +
    + + {/* File Path */} +
    + File Path: +
    + {entityDetails[name].filePath + ?.split('') + .map((path: string, i: number) => ( +

    + {path.trim() || ""} +

    + )) || "No file path"} +
    +
    + + {/* Related Entities list */} + {entityDetails[name].relatedEntities?.length > 0 && ( +
    + {entityDetails[name].relatedEntities.map((rel: any, idx: number) => ( +
    + Related Entity {idx + 1}: {rel.name} +
    + (Type: {rel.type || ""}) +
    +
    + Description: + {rel.description + ?.split('') + .map((part: string, j: number) => ( +

    + {part.trim()} +

    + )) || "No description found."} +
    +
    + ))} +
    + )} + + {/* Relationships list */} + {entityDetails[name].relationships?.length > 0 && ( +
    + Relationships: +
    + {entityDetails[name].relationships.map((rel: any, idx: number) => ( +
    +
    + From: {rel.from} +
    To: {rel.to} +
    +
    + Relation: +
    + {rel.relation + ?.split('') + .map((part: string, j: number) => ( +

    + {part.trim() || "No relation description provided."} +

    + )) || "No relation description provided."} +
    +
    +
    + Weight: {rel.weight || 1.0} + {rel.keywords && ( + + Keywords:{' '} + {rel.keywords + .split(',') + .map((kw: string, j: number) => ( + + {kw.trim()} + + ))} + + )} + +
    + +
    +
    + ))} +
    +
    + )} +
    + ) : ( +
    Failed to load details
    + )} +
    + ))} +
    + )} +
    +
    +
    + + {/* Status bar */} +
    +
    + Total entities: {entities.length} + Selected: {selectedEntities.length} + {filterMode !== 'none' && Filtered: {displayEntities.length}} + Orphans: {Object.values(entityOrphanMap).filter(Boolean).length} +
    +
    LightRAG Sanitizer
    +
    + + {/* Edit Entity Modal */} + {editEntityModalOpen && editEntityOriginalName && ( +
    +
    +

    + Edit Entity: {editEntityOriginalName} +

    + {editError && ( +
    + {editError} +
    + )} +
    + {savingEntity && ( +
    +
    +
    +

    Saving changes... Please wait.

    +
    +
    + )} +
    + {/* Entity Name */} +
    + + setEditEntityName(e.target.value)} + /> +
    + {/* Description */} +
    + +