Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1148,8 +1148,8 @@ The Mem0 Memory Tool supports three different backend configurations:
- Uses FAISS as the local vector store backend
- Requires faiss-cpu package for local vector storage

4. **Neptune Analytics** (Optional Graph backend for search enhancement):
- Uses Neptune Analytics as the graph store backend to enhance memory recall.
4. **Neptune Analytics** (Optional vector store backend):
- Uses Neptune Analytics as the vector store backend.
- Requires AWS credentials and Neptune Analytics configuration
```
# Configure your Neptune Analytics graph ID in the .env file:
Expand Down Expand Up @@ -1180,7 +1180,7 @@ The Mem0 Memory Tool supports three different backend configurations:
- If `MEM0_API_KEY` is set, the tool will use the Mem0 Platform
- If `OPENSEARCH_HOST` is set, the tool will use OpenSearch
- If neither is set, the tool will default to FAISS (requires `faiss-cpu` package)
- If `NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER` is set, the tool will configure Neptune Analytics as graph store to enhance memory search
- If `NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER` is set (and `OPENSEARCH_HOST` is not), the tool will configure Neptune Analytics as the vector store
- LLM configuration applies to all backend modes and allows customization of the language model used for memory processing

#### Bright Data Tool
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ dev = [
"pytest>=8.0.0,<10.0.0",
"ruff>=0.13.0,<0.14.0",
"responses>=0.6.1,<1.0.0",
"mem0ai>=0.1.104,<1.0.0",
"mem0ai>=2.0.0,<3.0.0",
"opensearch-py>=2.8.0,<3.0.0",
"nest-asyncio>=1.5.0,<2.0.0",
"playwright>=1.42.0,<2.0.0",
Expand All @@ -85,7 +85,7 @@ docs = [
]
mem0-memory = [
# Need to be optional as a fix for https://github.com/strands-agents/docs/issues/19
"mem0ai>=0.1.99,<1.0.0",
"mem0ai>=2.0.0,<3.0.0",
"opensearch-py>=2.8.0,<3.0.0",
]
local-chromium-browser = ["nest-asyncio>=1.5.0,<2.0.0", "playwright>=1.42.0,<2.0.0"]
Expand Down Expand Up @@ -149,6 +149,8 @@ extra-dependencies = [
"pytest-xdist>=3.0.0,<4.0.0",
"responses>=0.6.1,<1.0.0",
"pytest_asyncio>=0.25.0,<2.0.0",
# Local FAISS vector store used by the mem0_memory e2e test (tests_integ).
"faiss-cpu>=1.8.0,<2.0.0",
]
extra-args = ["-n", "auto", '-vv']

Expand Down
158 changes: 20 additions & 138 deletions src/strands_tools/mem0_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,32 +209,16 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
logger.debug("Using FAISS backend (Mem0Memory with FAISS)")
merged_config = self._append_faiss_config(config)

# Graph backend providers

# Graph backend providers
if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER") and os.environ.get("NEPTUNE_DATABASE_ENDPOINT"):
raise RuntimeError("""Conflicting backend configurations:
Both NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER and NEPTUNE_DATABASE_ENDPOINT environment variables are set.
Please specify only one graph backend.""")

if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
merged_config = self._append_neptune_analytics_graph_config(merged_config)

elif os.environ.get("NEPTUNE_DATABASE_ENDPOINT"):
logger.debug("Using Neptune Database graph backend (Mem0Memory with Neptune Database)")
merged_config = self._append_neptune_database_backend(merged_config)

return Mem0Memory.from_config(config_dict=merged_config)

def _append_neptune_analytics_vector_config(self, config: Optional[Dict] = None) -> Dict:
"""Update incoming configuration dictionary to include the configuration of Neptune Analytics vector backend.
"""Update incoming configuration dictionary to include the Neptune Analytics vector backend.

Args:
config: Optional configuration dictionary to override defaults.

Returns:
An configuration dict with graph backend.
A configuration dict with Neptune Analytics as the vector store.
"""
config = config or {}
config["vector_store"] = {
Expand All @@ -246,26 +230,6 @@ def _append_neptune_analytics_vector_config(self, config: Optional[Dict] = None)
}
return self._merge_config(config)

def _append_neptune_database_backend(self, config: Optional[Dict] = None) -> Dict:
"""Update incoming configuration dictionary to include the configuration of Neptune Database graph backend.

Args:
config: Optional configuration dictionary to override defaults.

Returns:
An configuration dict with graph backend.
"""
config = config or {}
config["graph_store"] = {
"provider": "neptunedb",
"config": {"endpoint": f"neptune-db://{os.environ.get('NEPTUNE_DATABASE_ENDPOINT')}"},
}
# To retrieve cosine similarity score instead for Faiss.
if "faiss" == config.get("vector_store", {}).get("provider"):
config["vector_store"]["config"]["distance_strategy"] = "cosine"

return config

def _append_opensearch_config(self, config: Optional[Dict] = None) -> Dict:
"""Update incoming configuration dictionary to include the configuration of OpenSearch vector backend.

Expand Down Expand Up @@ -338,21 +302,6 @@ def _append_faiss_config(self, config: Optional[Dict] = None) -> Dict:
}
return merged_config

def _append_neptune_analytics_graph_config(self, config: Dict) -> Dict:
"""Update incoming configuration dictionary to include the configuration of Neptune Analytics graph backend.

Args:
config: Configuration dictionary to add Neptune Analytics graph backend

Returns:
An configuration dict with graph backend.
"""
config["graph_store"] = {
"provider": "neptune",
"config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"},
}
return config

def _merge_config(self, config: Optional[Dict] = None) -> Dict:
"""Merge user-provided configuration with default configuration.

Expand Down Expand Up @@ -398,14 +347,30 @@ def list_memories(self, user_id: Optional[str] = None, agent_id: Optional[str] =
if not user_id and not agent_id:
raise ValueError("Either user_id or agent_id must be provided")

return self.mem0.get_all(user_id=user_id, agent_id=agent_id)
# mem0 >=2.0 requires entity ids to be passed via `filters` rather than as
# top-level keyword arguments.
filters = self._build_entity_filters(user_id, agent_id)
return self.mem0.get_all(filters=filters)

def search_memories(self, query: str, user_id: Optional[str] = None, agent_id: Optional[str] = None):
"""Search memories using semantic search."""
if not user_id and not agent_id:
raise ValueError("Either user_id or agent_id must be provided")

return self.mem0.search(query=query, user_id=user_id, agent_id=agent_id)
# mem0 >=2.0 requires entity ids to be passed via `filters` rather than as
# top-level keyword arguments.
filters = self._build_entity_filters(user_id, agent_id)
return self.mem0.search(query=query, filters=filters)

@staticmethod
def _build_entity_filters(user_id: Optional[str], agent_id: Optional[str]) -> Dict[str, str]:
"""Build a mem0 `filters` dict from the provided entity ids."""
filters: Dict[str, str] = {}
if user_id:
filters["user_id"] = user_id
if agent_id:
filters["agent_id"] = agent_id
return filters

def delete_memory(self, memory_id: str):
"""Delete a memory by ID."""
Expand Down Expand Up @@ -520,48 +485,6 @@ def format_retrieve_response(memories: List[Dict]) -> Panel:
return Panel(table, title="[bold green]Search Results", border_style="green")


def format_retrieve_graph_response(memories: List[Dict]) -> Panel:
"""Format retrieve response for graph data"""
if not memories:
return Panel(
"No graph memories found matching the query.", title="[bold yellow]No Matches", border_style="yellow"
)

table = Table(title="Search Results", show_header=True, header_style="bold magenta")
table.add_column("Source", style="cyan", width=25)
table.add_column("Relationship", style="yellow", width=45)
table.add_column("Destination", style="green", width=30)

for memory in memories:
source = memory.get("source", "N/A")
relationship = memory.get("relationship", "N/A")
destination = memory.get("destination", "N/A")

table.add_row(source, relationship, destination)

return Panel(table, title="[bold green]Search Results (Graph)", border_style="green")


def format_list_graph_response(memories: List[Dict]) -> Panel:
"""Format list response for graph data"""
if not memories:
return Panel("No graph memories found.", title="[bold yellow]No Memories", border_style="yellow")

table = Table(title="Graph Memories", show_header=True, header_style="bold magenta")
table.add_column("Source", style="cyan", width=25)
table.add_column("Relationship", style="yellow", width=45)
table.add_column("Target", style="green", width=30)

for memory in memories:
source = memory.get("source", "N/A")
relationship = memory.get("relationship", "N/A")
destination = memory.get("target", "N/A")

table.add_row(source, relationship, destination)

return Panel(table, title="[bold green]Memories List (Graph)", border_style="green")


def format_history_response(history: List[Dict]) -> Panel:
"""Format memory history response."""
if not history:
Expand Down Expand Up @@ -611,26 +534,6 @@ def format_store_response(results: List[Dict]) -> Panel:
return Panel(table, title="[bold green]Memory Stored", border_style="green")


def format_store_graph_response(memories: List[Dict]) -> Panel:
"""Format store response for graph data"""
if not memories:
return Panel("No graph memories stored.", title="[bold yellow]No Memories Stored", border_style="yellow")

table = Table(title="Graph Memories Stored", show_header=True, header_style="bold magenta")
table.add_column("Source", style="cyan", width=25)
table.add_column("Relationship", style="yellow", width=45)
table.add_column("Target", style="green", width=30)

for memory in memories:
source = memory[0].get("source", "N/A")
relationship = memory[0].get("relationship", "N/A")
destination = memory[0].get("target", "N/A")

table.add_row(source, relationship, destination)

return Panel(table, title="[bold green]Memories Stored (Graph)", border_style="green")


def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
"""
Memory management tool for storing, retrieving, and managing memories in Mem0.
Expand Down Expand Up @@ -743,13 +646,6 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_store_response(results_list)
console.print(panel)

# Process graph relations (If any)
if "relations" in results:
relationships_list = results.get("relations").get("added_entities", [])
results_list.extend(relationships_list)
panel_graph = format_store_graph_response(relationships_list)
console.print(panel_graph)

return ToolResult(
toolUseId=tool_use_id,
status="success",
Expand All @@ -774,13 +670,6 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_list_response(results_list)
console.print(panel)

# Process graph relations (If any)
if "relations" in memories:
relationships_list = memories.get("relations", [])
results_list.extend(relationships_list)
panel_graph = format_list_graph_response(relationships_list)
console.print(panel_graph)

return ToolResult(
toolUseId=tool_use_id,
status="success",
Expand All @@ -801,13 +690,6 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_retrieve_response(results_list)
console.print(panel)

# Process graph relations (If any)
if "relations" in memories:
relationships_list = memories.get("relations", [])
results_list.extend(relationships_list)
panel_graph = format_retrieve_graph_response(relationships_list)
console.print(panel_graph)

return ToolResult(
toolUseId=tool_use_id,
status="success",
Expand Down
14 changes: 1 addition & 13 deletions tests/test_mem0.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def test_mem0_service_client_init(mock_opensearch, mock_mem0_memory, mock_sessio
with pytest.raises(RuntimeError):
Mem0ServiceClient()

# Test with Neptune Analytics for both vector and graph
# Test with Neptune Analytics as the vector store
with patch.dict(
os.environ,
{
Expand All @@ -446,18 +446,6 @@ def test_mem0_service_client_init(mock_opensearch, mock_mem0_memory, mock_sessio
client = Mem0ServiceClient()
assert client.mem0 is not None

# Test with Neptune Database with OpenSearch
with patch.dict(
os.environ,
{
"OPENSEARCH_HOST": "test.opensearch.amazonaws.com",
"NEPTUNE_DATABASE_ENDPOINT": "xxx.us-west-2.neptune.amazonaws.com",
},
):
client = Mem0ServiceClient()
assert client.region == os.environ.get("AWS_REGION", "us-west-2")
assert client.mem0 is not None

# Test with custom config (OpenSearch)
custom_config = {
"embedder": {"provider": "custom", "config": {"model": "custom-model"}},
Expand Down
Loading
Loading