diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java index c7e64473..ac01713a 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/MemoryRecord.java @@ -6,7 +6,9 @@ import org.jetbrains.annotations.Nullable; import java.time.Instant; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * A memory record in the system. @@ -72,6 +74,17 @@ public class MemoryRecord { @JsonProperty("event_date") private Instant eventDate; + @NotNull + @JsonProperty("extraction_strategy") + private String extractionStrategy; + + @NotNull + @JsonProperty("extraction_strategy_config") + private Map extractionStrategyConfig; + + @NotNull + private Map metadata; + public MemoryRecord() { this.id = UlidCreator.getUlid().toString(); Instant now = Instant.now(); @@ -80,6 +93,9 @@ public MemoryRecord() { this.updatedAt = now; this.discreteMemoryExtracted = "f"; this.memoryType = MemoryType.MESSAGE; + this.extractionStrategy = "discrete"; + this.extractionStrategyConfig = new HashMap<>(); + this.metadata = new HashMap<>(); } public MemoryRecord(@NotNull String text) { @@ -233,6 +249,33 @@ public void setEventDate(@Nullable Instant eventDate) { this.eventDate = eventDate; } + @NotNull + public String getExtractionStrategy() { + return extractionStrategy; + } + + public void setExtractionStrategy(@NotNull String extractionStrategy) { + this.extractionStrategy = extractionStrategy; + } + + @NotNull + public Map getExtractionStrategyConfig() { + return extractionStrategyConfig; + } + + public void setExtractionStrategyConfig(@NotNull Map extractionStrategyConfig) { + this.extractionStrategyConfig = extractionStrategyConfig; + } + + @NotNull + public Map getMetadata() { + return metadata; + } + + public void setMetadata(@NotNull Map metadata) { + this.metadata = metadata; + } + @Override public String toString() { return "MemoryRecord{" + @@ -251,6 +294,9 @@ public String toString() { ", memoryType=" + memoryType + ", extractedFrom=" + extractedFrom + ", eventDate=" + eventDate + + ", extractionStrategy='" + extractionStrategy + '\'' + + ", extractionStrategyConfig=" + extractionStrategyConfig + + ", metadata=" + metadata + '}'; } @@ -282,6 +328,9 @@ public static class Builder { private Instant persistedAt; private List extractedFrom; private Instant eventDate; + private String extractionStrategy; + private Map extractionStrategyConfig; + private Map metadata; private Builder() { // Initialize with defaults for extracted memories (client-created long-term memories) @@ -292,6 +341,9 @@ private Builder() { this.updatedAt = now; this.discreteMemoryExtracted = "t"; // "t" for extracted memories this.memoryType = MemoryType.SEMANTIC; // SEMANTIC for long-term memories + this.extractionStrategy = "manual"; + this.extractionStrategyConfig = new HashMap<>(); + this.metadata = new HashMap<>(); } /** @@ -316,6 +368,9 @@ public Builder from(MemoryRecord record) { this.persistedAt = record.persistedAt; this.extractedFrom = record.extractedFrom; this.eventDate = record.eventDate; + this.extractionStrategy = record.extractionStrategy; + this.extractionStrategyConfig = record.extractionStrategyConfig; + this.metadata = record.metadata; return this; } @@ -479,6 +534,36 @@ public Builder eventDate(@Nullable Instant eventDate) { return this; } + /** + * Sets the extraction strategy. + * @param extractionStrategy the extraction strategy + * @return this builder + */ + public Builder extractionStrategy(@NotNull String extractionStrategy) { + this.extractionStrategy = extractionStrategy; + return this; + } + + /** + * Sets the extraction strategy configuration. + * @param extractionStrategyConfig the extraction strategy configuration + * @return this builder + */ + public Builder extractionStrategyConfig(@NotNull Map extractionStrategyConfig) { + this.extractionStrategyConfig = extractionStrategyConfig; + return this; + } + + /** + * Sets additional metadata. + * @param metadata the metadata map + * @return this builder + */ + public Builder metadata(@NotNull Map metadata) { + this.metadata = metadata; + return this; + } + /** * Builds the MemoryRecord instance. * @return a new MemoryRecord @@ -506,6 +591,9 @@ public MemoryRecord build() { record.persistedAt = this.persistedAt; record.extractedFrom = this.extractedFrom; record.eventDate = this.eventDate; + record.extractionStrategy = this.extractionStrategy; + record.extractionStrategyConfig = this.extractionStrategyConfig; + record.metadata = this.metadata; return record; } } diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java index 2e0c9ef9..44deb44d 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/longtermemory/SearchRequest.java @@ -48,6 +48,10 @@ public class SearchRequest { @JsonProperty("distance_threshold") private Double distanceThreshold; + @Nullable + @JsonProperty("extraction_strategy") + private String extractionStrategy; + private int limit = 10; private int offset = 0; @@ -178,6 +182,15 @@ public void setDistanceThreshold(@Nullable Double distanceThreshold) { this.distanceThreshold = distanceThreshold; } + @Nullable + public String getExtractionStrategy() { + return extractionStrategy; + } + + public void setExtractionStrategy(@Nullable String extractionStrategy) { + this.extractionStrategy = extractionStrategy; + } + public int getLimit() { return limit; } @@ -279,6 +292,7 @@ public String toString() { ", entities=" + entities + ", userId='" + userId + '\'' + ", distanceThreshold=" + distanceThreshold + + ", extractionStrategy='" + extractionStrategy + '\'' + ", limit=" + limit + ", offset=" + offset + ", recencyBoost=" + recencyBoost + @@ -356,6 +370,11 @@ public Builder distanceThreshold(@Nullable Double distanceThreshold) { return this; } + public Builder extractionStrategy(@Nullable String extractionStrategy) { + request.extractionStrategy = extractionStrategy; + return this; + } + public Builder limit(int limit) { request.limit = limit; return this; diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/summaryview/SummaryViewPartitionResult.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/summaryview/SummaryViewPartitionResult.java index 1861c60e..ae4848b3 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/summaryview/SummaryViewPartitionResult.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/models/summaryview/SummaryViewPartitionResult.java @@ -24,6 +24,12 @@ public class SummaryViewPartitionResult { @JsonProperty("memory_count") private int memoryCount; + private boolean empty; + + @Nullable + @JsonProperty("empty_reason") + private String emptyReason; + @Nullable @JsonProperty("computed_at") private String computedAt; @@ -77,6 +83,23 @@ public void setMemoryCount(int memoryCount) { this.memoryCount = memoryCount; } + public boolean isEmpty() { + return empty; + } + + public void setEmpty(boolean empty) { + this.empty = empty; + } + + @Nullable + public String getEmptyReason() { + return emptyReason; + } + + public void setEmptyReason(@Nullable String emptyReason) { + this.emptyReason = emptyReason; + } + @Nullable public String getComputedAt() { return computedAt; @@ -91,6 +114,8 @@ public String toString() { return "SummaryViewPartitionResult{" + "viewId='" + viewId + '\'' + ", group=" + group + + ", empty=" + empty + + ", emptyReason='" + emptyReason + '\'' + ", memoryCount=" + memoryCount + '}'; } diff --git a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java index faeaf090..b0530148 100644 --- a/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java +++ b/agent-memory-client/agent-memory-client-java/src/main/java/com/redis/agentmemory/services/LongTermMemoryService.java @@ -106,6 +106,9 @@ public MemoryRecordResults searchLongTermMemories(@NotNull SearchRequest request if (request.getDistanceThreshold() != null) { payload.put("distance_threshold", request.getDistanceThreshold()); } + if (request.getExtractionStrategy() != null) { + payload.put("extraction_strategy", Map.of("eq", request.getExtractionStrategy())); + } // Add recency boost parameters if present if (request.getRecencyBoost() != null) { diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java index f3bf830d..b2f852eb 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/JsonSerializationTest.java @@ -54,17 +54,25 @@ void testMemoryRecordSerialization() throws Exception { record.setMemoryType(MemoryType.SEMANTIC); record.setTopics(Arrays.asList("topic1", "topic2")); record.setEntities(Arrays.asList("entity1", "entity2")); + record.setExtractionStrategy("summary"); + record.setExtractionStrategyConfig(Map.of("summary_version", "v1")); + record.setMetadata(Map.of("message_count", 2)); String json = objectMapper.writeValueAsString(record); assertNotNull(json); assertTrue(json.contains("\"text\":\"Test memory\"")); assertTrue(json.contains("\"user_id\":\"user-123\"")); assertTrue(json.contains("\"memory_type\":\"semantic\"")); + assertTrue(json.contains("\"extraction_strategy\":\"summary\"")); + assertTrue(json.contains("\"metadata\"")); MemoryRecord deserialized = objectMapper.readValue(json, MemoryRecord.class); assertEquals("Test memory", deserialized.getText()); assertEquals("user-123", deserialized.getUserId()); assertEquals(MemoryType.SEMANTIC, deserialized.getMemoryType()); + assertEquals("summary", deserialized.getExtractionStrategy()); + assertEquals("v1", deserialized.getExtractionStrategyConfig().get("summary_version")); + assertEquals(2, deserialized.getMetadata().get("message_count")); assertNotNull(deserialized.getTopics()); assertEquals(2, deserialized.getTopics().size()); } @@ -105,7 +113,6 @@ void testWorkingMemorySerialization() throws Exception { assertNotNull(json); assertTrue(json.contains("\"session_id\":\"session-123\"")); assertTrue(json.contains("\"user_id\":\"user-456\"")); - WorkingMemory deserialized = objectMapper.readValue(json, WorkingMemory.class); assertEquals("session-123", deserialized.getSessionId()); assertEquals("user-456", deserialized.getUserId()); diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java index 84de099a..853a1bf5 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/models/longtermemory/MemoryRecordTest.java @@ -3,6 +3,7 @@ import org.junit.jupiter.api.Test; import java.util.Arrays; +import java.util.Map; import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -19,6 +20,9 @@ void testDefaultConstructor() { assertNotNull(record.getUpdatedAt()); assertEquals("f", record.getDiscreteMemoryExtracted()); assertEquals(MemoryType.MESSAGE, record.getMemoryType()); + assertEquals("discrete", record.getExtractionStrategy()); + assertNotNull(record.getExtractionStrategyConfig()); + assertNotNull(record.getMetadata()); } @Test @@ -45,6 +49,9 @@ void testSettersAndGetters() { record.setEntities(entities); record.setMemoryType(MemoryType.SEMANTIC); + record.setExtractionStrategy("summary"); + record.setExtractionStrategyConfig(Map.of("summary_version", "v1")); + record.setMetadata(Map.of("message_count", 2)); assertEquals("Test memory", record.getText()); assertEquals("session-123", record.getSessionId()); @@ -53,5 +60,8 @@ void testSettersAndGetters() { assertEquals(topics, record.getTopics()); assertEquals(entities, record.getEntities()); assertEquals(MemoryType.SEMANTIC, record.getMemoryType()); + assertEquals("summary", record.getExtractionStrategy()); + assertEquals("v1", record.getExtractionStrategyConfig().get("summary_version")); + assertEquals(2, record.getMetadata().get("message_count")); } } diff --git a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java index 6ffea49a..9b16e444 100644 --- a/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java +++ b/agent-memory-client/agent-memory-client-java/src/test/java/com/redis/agentmemory/services/LongTermMemoryServiceTest.java @@ -461,4 +461,23 @@ void testSearchRequestBuilder_AllRecencyFields() { assertFalse(request.getServerSideRecency()); } + @Test + void testSearchRequestBuilder_ExtractionStrategy() throws Exception { + mockServer.enqueue(new MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("{\"memories\":[],\"total\":0}")); + + SearchRequest request = SearchRequest.builder() + .text("query") + .extractionStrategy("summary") + .build(); + + client.longTermMemory().searchLongTermMemories(request); + + RecordedRequest recorded = mockServer.takeRequest(); + String requestBody = recorded.getBody().readUtf8(); + assertTrue(requestBody.contains("\"extraction_strategy\":{\"eq\":\"summary\"}")); + } + } diff --git a/agent-memory-client/agent-memory-client-js/src/client.test.ts b/agent-memory-client/agent-memory-client-js/src/client.test.ts index b22cb537..e1ec0787 100644 --- a/agent-memory-client/agent-memory-client-js/src/client.test.ts +++ b/agent-memory-client/agent-memory-client-js/src/client.test.ts @@ -468,6 +468,7 @@ describe("MemoryAPIClient", () => { expect(result.session_id).toBe("test"); expect(callCount).toBe(2); }); + }); describe("deleteWorkingMemory", () => { @@ -712,6 +713,18 @@ describe("MemoryAPIClient", () => { expect(callBody.memory_type).toEqual({ eq: "episodic" }); }); + it("should handle ExtractionStrategy filter class", async () => { + const { ExtractionStrategy } = await import("./filters"); + mockFetch = createMockFetch({ memories: [], total: 0 }); + client["fetchFn"] = mockFetch; + await client.searchLongTermMemory({ + text: "test", + extractionStrategy: new ExtractionStrategy({ eq: "summary" }), + }); + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody.extraction_strategy).toEqual({ eq: "summary" }); + }); + it("should handle EventDate filter class", async () => { const { EventDate } = await import("./filters"); mockFetch = createMockFetch({ memories: [], total: 0 }); @@ -828,6 +841,17 @@ describe("MemoryAPIClient", () => { expect(callBody.memory_type).toEqual({ eq: "episodic" }); }); + it("should handle plain object extractionStrategy filter", async () => { + mockFetch = createMockFetch({ memories: [], total: 0 }); + client["fetchFn"] = mockFetch; + await client.searchLongTermMemory({ + text: "test", + extractionStrategy: { eq: "summary" }, + }); + const callBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(callBody.extraction_strategy).toEqual({ eq: "summary" }); + }); + it("should handle plain object eventDate filter", async () => { mockFetch = createMockFetch({ memories: [], total: 0 }); client["fetchFn"] = mockFetch; diff --git a/agent-memory-client/agent-memory-client-js/src/client.ts b/agent-memory-client/agent-memory-client-js/src/client.ts index c25a4799..7ee40270 100644 --- a/agent-memory-client/agent-memory-client-js/src/client.ts +++ b/agent-memory-client/agent-memory-client-js/src/client.ts @@ -20,6 +20,7 @@ import { LastAccessed, EventDate, MemoryType, + ExtractionStrategy, } from "./filters"; import { AckResponse, @@ -82,6 +83,7 @@ export interface SearchOptions { lastAccessed?: LastAccessed | { gte?: Date | string; lte?: Date | string; eq?: Date | string }; userId?: UserId | { eq?: string; in_?: string[]; not_eq?: string; not_in?: string[] }; memoryType?: MemoryType | { eq?: string; in_?: string[]; not_eq?: string; not_in?: string[] }; + extractionStrategy?: ExtractionStrategy | { eq?: string; in_?: string[]; not_eq?: string; not_in?: string[] }; eventDate?: EventDate | { gte?: Date | string; lte?: Date | string; eq?: Date | string }; distanceThreshold?: number; limit?: number; @@ -443,6 +445,12 @@ export class MemoryAPIClient { ? options.memoryType.toJSON() : options.memoryType; } + if (options.extractionStrategy) { + body.extraction_strategy = + options.extractionStrategy instanceof ExtractionStrategy + ? options.extractionStrategy.toJSON() + : options.extractionStrategy; + } if (options.eventDate) { body.event_date = options.eventDate instanceof EventDate diff --git a/agent-memory-client/agent-memory-client-js/src/filters.test.ts b/agent-memory-client/agent-memory-client-js/src/filters.test.ts index 1cddd772..7b13cd3a 100644 --- a/agent-memory-client/agent-memory-client-js/src/filters.test.ts +++ b/agent-memory-client/agent-memory-client-js/src/filters.test.ts @@ -9,6 +9,7 @@ import { LastAccessed, EventDate, MemoryType, + ExtractionStrategy, } from "./filters"; describe("SessionId", () => { @@ -258,3 +259,30 @@ describe("MemoryType", () => { expect(filter.toJSON()).toEqual({}); }); }); + +describe("ExtractionStrategy", () => { + it("should create with eq option", () => { + const filter = new ExtractionStrategy({ eq: "summary" }); + expect(filter.toJSON()).toEqual({ eq: "summary" }); + }); + + it("should create with in_ option", () => { + const filter = new ExtractionStrategy({ in_: ["summary", "discrete"] }); + expect(filter.toJSON()).toEqual({ in_: ["summary", "discrete"] }); + }); + + it("should create with not_eq option", () => { + const filter = new ExtractionStrategy({ not_eq: "manual" }); + expect(filter.toJSON()).toEqual({ not_eq: "manual" }); + }); + + it("should create with not_in option", () => { + const filter = new ExtractionStrategy({ not_in: ["manual", "message"] }); + expect(filter.toJSON()).toEqual({ not_in: ["manual", "message"] }); + }); + + it("should create empty filter", () => { + const filter = new ExtractionStrategy(); + expect(filter.toJSON()).toEqual({}); + }); +}); diff --git a/agent-memory-client/agent-memory-client-js/src/filters.ts b/agent-memory-client/agent-memory-client-js/src/filters.ts index 59c0617f..3adf032c 100644 --- a/agent-memory-client/agent-memory-client-js/src/filters.ts +++ b/agent-memory-client/agent-memory-client-js/src/filters.ts @@ -293,3 +293,34 @@ export class MemoryType { return result; } } + +/** + * Filter by extraction strategy + */ +export class ExtractionStrategy { + eq?: string; + in_?: string[]; + not_eq?: string; + not_in?: string[]; + + constructor(options: { + eq?: string; + in_?: string[]; + not_eq?: string; + not_in?: string[]; + } = {}) { + this.eq = options.eq; + this.in_ = options.in_; + this.not_eq = options.not_eq; + this.not_in = options.not_in; + } + + toJSON(): Record { + const result: Record = {}; + if (this.eq !== undefined) result.eq = this.eq; + if (this.in_ !== undefined) result.in_ = this.in_; + if (this.not_eq !== undefined) result.not_eq = this.not_eq; + if (this.not_in !== undefined) result.not_in = this.not_in; + return result; + } +} diff --git a/agent-memory-client/agent-memory-client-js/src/index.ts b/agent-memory-client/agent-memory-client-js/src/index.ts index 567359d3..1b95100c 100644 --- a/agent-memory-client/agent-memory-client-js/src/index.ts +++ b/agent-memory-client/agent-memory-client-js/src/index.ts @@ -27,6 +27,7 @@ export { LastAccessed, EventDate, MemoryType, + ExtractionStrategy, } from "./filters"; // Export models diff --git a/agent-memory-client/agent-memory-client-js/src/models.test.ts b/agent-memory-client/agent-memory-client-js/src/models.test.ts index c0be969e..d28d7a52 100644 --- a/agent-memory-client/agent-memory-client-js/src/models.test.ts +++ b/agent-memory-client/agent-memory-client-js/src/models.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from "vitest"; -import { generateId, MemoryTypeEnum } from "./models"; +import { generateId, MemoryTypeEnum, type MemoryRecord } from "./models"; describe("generateId", () => { it("should generate a ULID string", () => { @@ -22,3 +22,19 @@ describe("MemoryTypeEnum", () => { expect(MemoryTypeEnum.MESSAGE).toBe("message"); }); }); + +describe("MemoryRecord", () => { + it("should allow extraction metadata fields", () => { + const record: MemoryRecord = { + id: "mem-1", + text: "Thread summary", + extraction_strategy: "summary", + extraction_strategy_config: { summary_version: "v1" }, + metadata: { message_count: 2 }, + }; + + expect(record.extraction_strategy).toBe("summary"); + expect(record.extraction_strategy_config).toEqual({ summary_version: "v1" }); + expect(record.metadata).toEqual({ message_count: 2 }); + }); +}); diff --git a/agent-memory-client/agent-memory-client-js/src/models.ts b/agent-memory-client/agent-memory-client-js/src/models.ts index 99423127..9834a27f 100644 --- a/agent-memory-client/agent-memory-client-js/src/models.ts +++ b/agent-memory-client/agent-memory-client-js/src/models.ts @@ -116,6 +116,12 @@ export interface MemoryRecord { extracted_from?: string[] | null; /** Date/time when the event described in this memory occurred */ event_date?: string | null; + /** Memory extraction strategy used when this was promoted from working memory */ + extraction_strategy?: string; + /** Configuration for the extraction strategy used */ + extraction_strategy_config?: Record; + /** Additional non-indexed metadata for provenance and display */ + metadata?: Record; } /** JSON value types for working memory data */ @@ -282,6 +288,7 @@ export interface SearchRequestParams { last_accessed?: LastAccessedFilter | null; user_id?: UserIdFilter | null; memory_type?: MemoryTypeFilter | null; + extraction_strategy?: ExtractionStrategyFilter | null; event_date?: EventDateFilter | null; distance_threshold?: number | null; limit?: number; @@ -356,6 +363,13 @@ export interface MemoryTypeFilter { not_in?: string[] | null; } +export interface ExtractionStrategyFilter { + eq?: string | null; + in_?: string[] | null; + not_eq?: string | null; + not_in?: string[] | null; +} + // ==================== Forget ==================== /** @@ -451,6 +465,10 @@ export interface SummaryViewPartitionResult { summary: string; /** Number of memories that contributed to this summary */ memory_count: number; + /** Whether this partition had no matching memories */ + empty?: boolean; + /** Machine-readable reason for an empty partition */ + empty_reason?: string | null; /** When this summary was computed */ computed_at?: string; } diff --git a/agent-memory-client/agent_memory_client/client.py b/agent-memory-client/agent_memory_client/client.py index ffe028c1..6c4b093a 100644 --- a/agent-memory-client/agent_memory_client/client.py +++ b/agent-memory-client/agent_memory_client/client.py @@ -27,6 +27,7 @@ from .filters import ( CreatedAt, Entities, + ExtractionStrategy, LastAccessed, MemoryType, Namespace, @@ -400,6 +401,8 @@ async def get_or_create_working_memory( memories=[], data={}, user_id=user_id, + long_term_memory_strategy=long_term_memory_strategy + or MemoryStrategyConfig(), ) created_memory = await self.put_working_memory( @@ -1046,6 +1049,7 @@ async def search_long_term_memory( user_id: UserId | dict[str, Any] | None = None, distance_threshold: float | None = None, memory_type: MemoryType | dict[str, Any] | None = None, + extraction_strategy: ExtractionStrategy | dict[str, Any] | None = None, recency: RecencyConfig | None = None, limit: int = 10, offset: int = 0, @@ -1068,6 +1072,7 @@ async def search_long_term_memory( user_id: Optional user ID filter distance_threshold: Optional distance threshold for search results memory_type: Optional memory type filter + extraction_strategy: Optional extraction strategy filter limit: Maximum number of results to return (default: 10) offset: Offset for pagination (default: 0) optimize_query: Whether to optimize the query for semantic (vector) search using a fast model; ignored for keyword and hybrid modes (default: False) @@ -1111,6 +1116,8 @@ async def search_long_term_memory( last_accessed = LastAccessed(**last_accessed) if isinstance(memory_type, dict): memory_type = MemoryType(**memory_type) + if isinstance(extraction_strategy, dict): + extraction_strategy = ExtractionStrategy(**extraction_strategy) # Apply default namespace if needed and no namespace filter specified if namespace is None and self.config.default_namespace is not None: @@ -1146,6 +1153,10 @@ async def search_long_term_memory( payload["user_id"] = user_id.model_dump(exclude_none=True) if memory_type: payload["memory_type"] = memory_type.model_dump(exclude_none=True) + if extraction_strategy: + payload["extraction_strategy"] = extraction_strategy.model_dump( + exclude_none=True + ) if distance_threshold is not None: payload["distance_threshold"] = distance_threshold payload["search_mode"] = ( diff --git a/agent-memory-client/agent_memory_client/filters.py b/agent-memory-client/agent_memory_client/filters.py index 9a81345f..187921b4 100644 --- a/agent-memory-client/agent_memory_client/filters.py +++ b/agent-memory-client/agent_memory_client/filters.py @@ -99,3 +99,12 @@ class MemoryType(BaseFilter): in_: list[str] | None = None not_eq: str | None = None not_in: list[str] | None = None + + +class ExtractionStrategy(BaseFilter): + """Filter by extraction strategy""" + + eq: str | None = None + in_: list[str] | None = None + not_eq: str | None = None + not_in: list[str] | None = None diff --git a/agent-memory-client/agent_memory_client/models.py b/agent-memory-client/agent_memory_client/models.py index f20cea09..c466dec6 100644 --- a/agent-memory-client/agent_memory_client/models.py +++ b/agent-memory-client/agent_memory_client/models.py @@ -18,6 +18,8 @@ logger = logging.getLogger(__name__) +JSONTypes = str | float | int | bool | list[Any] | dict[str, Any] | None + # Model name literals for model-specific window sizes ModelNameLiteral = Literal[ "gpt-3.5-turbo", @@ -244,6 +246,18 @@ class MemoryRecord(BaseModel): default=None, description="Date/time when the event described in this memory occurred (primarily for episodic memories)", ) + extraction_strategy: str = Field( + default="discrete", + description="Memory extraction strategy used when this was promoted from working memory", + ) + extraction_strategy_config: dict[str, Any] = Field( + default_factory=dict, + description="Configuration for the extraction strategy used", + ) + metadata: dict[str, JSONTypes] = Field( + default_factory=dict, + description="Additional non-indexed metadata for provenance and display", + ) @field_validator("topics", "entities", "extracted_from", mode="after") @classmethod @@ -273,9 +287,6 @@ class ClientMemoryRecord(MemoryRecord): ) -JSONTypes = str | float | int | bool | list[Any] | dict[str, Any] - - class WorkingMemory(BaseModel): """Working memory for a session - contains both messages and structured memory records""" @@ -519,10 +530,16 @@ class SummaryViewPartitionResult(BaseModel): group: dict[str, str] = Field( description="Concrete values for the view's group_by fields" ) - summary: str = Field(description="Summarized text for this partition") + summary: str = Field(default="", description="Summarized text for this partition") memory_count: int = Field( description="Number of memories that contributed to this summary" ) + empty: bool = Field( + default=False, description="Whether this partition had no matching memories" + ) + empty_reason: str | None = Field( + default=None, description="Machine-readable reason for an empty partition" + ) computed_at: str | None = Field( default=None, description="When this summary was computed" ) diff --git a/agent_memory_server/api.py b/agent_memory_server/api.py index 5fd9f05d..cefed05f 100644 --- a/agent_memory_server/api.py +++ b/agent_memory_server/api.py @@ -753,7 +753,14 @@ async def search_long_term_memory( try: had_any_strict_filters = any( key in kwargs and kwargs[key] is not None - for key in ("topics", "entities", "namespace", "memory_type", "event_date") + for key in ( + "topics", + "entities", + "namespace", + "memory_type", + "extraction_strategy", + "event_date", + ) ) if ( raw_results.total == 0 @@ -762,7 +769,14 @@ async def search_long_term_memory( == SearchModeEnum.SEMANTIC ): fallback_kwargs = dict(kwargs) - for key in ("topics", "entities", "namespace", "memory_type", "event_date"): + for key in ( + "topics", + "entities", + "namespace", + "memory_type", + "extraction_strategy", + "event_date", + ): fallback_kwargs.pop(key, None) def _vals(f): @@ -781,6 +795,9 @@ def _vals(f): entities_vals = _vals(filters.get("entities")) if filters else [] namespace_vals = _vals(filters.get("namespace")) if filters else [] memory_type_vals = _vals(filters.get("memory_type")) if filters else [] + extraction_strategy_vals = ( + _vals(filters.get("extraction_strategy")) if filters else [] + ) hint_parts: list[str] = [] if topics_vals: @@ -793,6 +810,11 @@ def _vals(f): ) if memory_type_vals: hint_parts.append(f"type: {', '.join(sorted(set(memory_type_vals)))}") + if extraction_strategy_vals: + hint_parts.append( + "extraction strategy: " + + ", ".join(sorted(set(extraction_strategy_vals))) + ) base_text = payload.text or "" hint_suffix = f" ({'; '.join(hint_parts)})" if hint_parts else "" @@ -1152,12 +1174,20 @@ def _validate_summary_view_keys(payload: CreateSummaryViewRequest) -> None: ), ) - allowed_group_by = {"user_id", "namespace", "session_id", "memory_type"} + allowed_group_by = { + "user_id", + "namespace", + "session_id", + "memory_type", + } allowed_filters = { "user_id", "namespace", "session_id", "memory_type", + "extraction_strategy", + "topics", + "event_date", } invalid_group = [k for k in payload.group_by if k not in allowed_group_by] diff --git a/agent_memory_server/filters.py b/agent_memory_server/filters.py index e6721aaf..f38decc6 100644 --- a/agent_memory_server/filters.py +++ b/agent_memory_server/filters.py @@ -256,6 +256,10 @@ def __init__(self, **data): super().__init__(**data) +class ExtractionStrategy(TagFilter): + field: str = "extraction_strategy" + + class EventDate(DateTimeFilter): field: str = "event_date" diff --git a/agent_memory_server/long_term_memory.py b/agent_memory_server/long_term_memory.py index ad2709fa..7aa7b259 100644 --- a/agent_memory_server/long_term_memory.py +++ b/agent_memory_server/long_term_memory.py @@ -1,3 +1,4 @@ +import hashlib import json import logging import numbers @@ -23,6 +24,7 @@ CreatedAt, Entities, EventDate, + ExtractionStrategy, LastAccessed, MemoryHash, MemoryType, @@ -38,6 +40,7 @@ MemoryRecord, MemoryRecordResult, MemoryRecordResults, + MemoryStrategyConfig, MemoryTypeEnum, SearchModeEnum, ) @@ -417,6 +420,240 @@ async def run_delayed_extraction( return 0 +def _build_thread_extraction_context( + working_memory: Any, +) -> tuple[str, dict[str, Any]]: + conversation_messages = [] + source_message_ids = [] + message_context = [] + created_at_values = [] + + for msg in working_memory.messages: + role_prefix = ( + f"[{msg.role.upper()}]: " if hasattr(msg, "role") and msg.role else "" + ) + conversation_messages.append(f"{role_prefix}{msg.content}") + source_message_ids.append(msg.id) + created_at_values.append(msg.created_at) + message_context.append( + { + "id": msg.id, + "role": msg.role, + "created_at": msg.created_at.isoformat(), + } + ) + + full_conversation = "\n".join(conversation_messages) + created_at_min = min(created_at_values) if created_at_values else None + created_at_max = max(created_at_values) if created_at_values else None + + fingerprint_parts = [ + "\x1f".join( + [ + msg.id or "", + msg.created_at.isoformat() if msg.created_at else "", + msg.role or "", + msg.content or "", + ] + ) + for msg in working_memory.messages + ] + source_message_fingerprint = hashlib.sha256( + "\x1e".join(fingerprint_parts).encode("utf-8") + ).hexdigest() + + context: dict[str, Any] = { + "session_id": working_memory.session_id, + "namespace": working_memory.namespace, + "user_id": working_memory.user_id, + "message_count": len(working_memory.messages), + "source_message_ids": source_message_ids, + "messages": message_context, + "created_at_min": created_at_min.isoformat() if created_at_min else None, + "created_at_max": created_at_max.isoformat() if created_at_max else None, + "source_message_fingerprint": source_message_fingerprint, + } + return full_conversation, context + + +def _thread_summary_memory_id( + *, + namespace: str | None, + user_id: str | None, + session_id: str, +) -> str: + material = "|".join([namespace or "", user_id or "", session_id]) + digest = hashlib.sha256(material.encode("utf-8")).hexdigest()[:24] + return f"thread_summary_{digest}" + + +def _parse_optional_extracted_datetime( + memory_data: dict[str, Any], field_name: str, *, session_id: str +) -> datetime | None: + raw_value = memory_data.get(field_name) + if raw_value is None: + return None + if isinstance(raw_value, datetime): + return raw_value + if isinstance(raw_value, str) and raw_value: + try: + return parse_iso8601_datetime(raw_value) + except ValueError: + logger.warning( + "Skipping invalid extracted %s %r for memory %r in session %s", + field_name, + raw_value, + memory_data.get("text"), + session_id, + ) + return None + + +async def _existing_summary_matches_source( + *, + memory_id: str, + summary_version: str, + source_message_fingerprint: str, +) -> bool: + try: + from agent_memory_server.filters import Id + + db = await get_memory_vector_db() + results = await db.list_memories(id=Id(eq=memory_id), limit=1) + except Exception: + logger.exception("Failed checking existing summary memory %s", memory_id) + return False + + if not results.memories: + return False + + metadata = results.memories[0].metadata or {} + return ( + metadata.get("summary_version") == summary_version + and metadata.get("source_message_fingerprint") == source_message_fingerprint + ) + + +def _memory_data_to_record( + *, + memory_data: dict[str, Any], + strategy_config: MemoryStrategyConfig, + thread_context: dict[str, Any], + extraction_time: datetime, +) -> MemoryRecord: + strategy_type = strategy_config.strategy + is_summary = strategy_type == "summary" + summary_version = str( + strategy_config.config.get("summary_version", "thread-summary-v1") + ) + + event_date = _parse_optional_extracted_datetime( + memory_data, "event_date", session_id=str(thread_context["session_id"]) + ) + created_at = ( + _parse_optional_extracted_datetime( + memory_data, "created_at", session_id=str(thread_context["session_id"]) + ) + or extraction_time + ) + + topics = sanitize_tag_values(memory_data.get("topics", [])) + if is_summary: + configured_topics = ( + sanitize_tag_values(strategy_config.config.get("topics", [])) or [] + ) + topics = sorted({*(topics or []), *configured_topics, "thread-summary"}) + + metadata: dict[str, Any] = dict(memory_data.get("metadata") or {}) + if is_summary: + metadata.update( + { + "source_session_id": thread_context.get("session_id"), + "message_count": thread_context.get("message_count"), + "source_message_ids": thread_context.get("source_message_ids", []), + "source_created_at_min": thread_context.get("created_at_min"), + "source_created_at_max": thread_context.get("created_at_max"), + "source_message_fingerprint": thread_context.get( + "source_message_fingerprint" + ), + "summary_version": summary_version, + } + ) + + return MemoryRecord( + id=( + _thread_summary_memory_id( + namespace=thread_context.get("namespace"), + user_id=thread_context.get("user_id"), + session_id=str(thread_context["session_id"]), + ) + if is_summary + else str(ULID()) + ), + text=memory_data["text"], + memory_type=( + MemoryTypeEnum.SEMANTIC + if is_summary + else memory_data.get("type", "semantic") + ), + topics=topics, + entities=sanitize_tag_values(memory_data.get("entities", [])), + event_date=event_date, + session_id=thread_context.get("session_id"), + namespace=thread_context.get("namespace"), + user_id=thread_context.get("user_id"), + created_at=created_at, + extraction_strategy=strategy_type, + extraction_strategy_config=strategy_config.config, + extracted_from=( + thread_context.get("source_message_ids", []) + if is_summary + else memory_data.get("extracted_from") + ), + metadata=metadata, + discrete_memory_extracted="t", + ) + + +def _coalesce_summary_memory_data( + memories_data: list[dict[str, Any]], +) -> list[dict[str, Any]]: + valid_memories = [memory for memory in memories_data if memory.get("text")] + if len(valid_memories) <= 1: + return valid_memories + + coalesced = dict(valid_memories[0]) + coalesced["text"] = "\n\n".join( + str(memory["text"]).strip() + for memory in valid_memories + if str(memory.get("text", "")).strip() + ) + coalesced["type"] = "semantic" + coalesced["topics"] = sorted( + { + topic + for memory in valid_memories + for topic in (sanitize_tag_values(memory.get("topics", [])) or []) + } + ) + coalesced["entities"] = sorted( + { + entity + for memory in valid_memories + for entity in (sanitize_tag_values(memory.get("entities", [])) or []) + } + ) + + metadata: dict[str, Any] = {} + for memory in valid_memories: + if isinstance(memory.get("metadata"), dict): + metadata.update(memory["metadata"]) + if metadata: + coalesced["metadata"] = metadata + + return [coalesced] + + async def extract_memories_from_session_thread( session_id: str, namespace: str | None = None, @@ -447,16 +684,7 @@ async def extract_memories_from_session_thread( logger.info(f"No working memory messages found for session {session_id}") return [] - # Build full conversation context from all messages - conversation_messages = [] - for msg in working_memory.messages: - # Include role and content for better context - role_prefix = ( - f"[{msg.role.upper()}]: " if hasattr(msg, "role") and msg.role else "" - ) - conversation_messages.append(f"{role_prefix}{msg.content}") - - full_conversation = "\n".join(conversation_messages) + full_conversation, thread_context = _build_thread_extraction_context(working_memory) logger.info( f"Extracting memories from {len(working_memory.messages)} messages in session {session_id}" @@ -465,52 +693,65 @@ async def extract_memories_from_session_thread( f"Full conversation context length: {len(full_conversation)} characters" ) - # Use the new memory strategy system for extraction from agent_memory_server.memory_strategies import get_memory_strategy + strategy_config = working_memory.long_term_memory_strategy or MemoryStrategyConfig() try: - # Get the discrete memory strategy for contextual grounding - strategy = get_memory_strategy("discrete") + if strategy_config.strategy == "summary": + summary_version = str( + strategy_config.config.get("summary_version", "thread-summary-v1") + ) + summary_id = _thread_summary_memory_id( + namespace=working_memory.namespace, + user_id=working_memory.user_id, + session_id=working_memory.session_id, + ) + if await _existing_summary_matches_source( + memory_id=summary_id, + summary_version=summary_version, + source_message_fingerprint=str( + thread_context["source_message_fingerprint"] + ), + ): + logger.info( + "Skipping unchanged summary extraction for session %s", + session_id, + ) + return [] - # Extract memories using the strategy - memories_data = await strategy.extract_memories(full_conversation) + strategy = get_memory_strategy( + strategy_config.strategy, + **strategy_config.config, + ) + memories_data = await strategy.extract_memories( + full_conversation, context=thread_context + ) logger.info( - f"Extracted {len(memories_data)} memories from session thread {session_id}" + "Extracted %d memories from session thread %s using strategy %s", + len(memories_data or []), + session_id, + strategy_config.strategy, ) - # Convert to MemoryRecord objects - extracted_memories = [] - for memory_data in memories_data: - event_date = None - event_date_str = memory_data.get("event_date") - if event_date_str: - try: - event_date = parse_iso8601_datetime(event_date_str) - except ValueError: - logger.warning( - "Skipping invalid extracted event_date %r for memory %r in session %s", - event_date_str, - memory_data.get("text"), - session_id, - ) - - memory = MemoryRecord( - id=str(ULID()), - text=memory_data["text"], - memory_type=memory_data.get("type", "semantic"), - topics=sanitize_tag_values(memory_data.get("topics", [])), - entities=sanitize_tag_values(memory_data.get("entities", [])), - event_date=event_date, - session_id=session_id, - namespace=namespace, - user_id=user_id, - discrete_memory_extracted="t", # Mark as extracted + valid_memories_data = [ + memory_data + for memory_data in memories_data or [] + if memory_data.get("text") + ] + if strategy_config.strategy == "summary": + valid_memories_data = _coalesce_summary_memory_data(valid_memories_data) + + extraction_time = datetime.now(UTC) + return [ + _memory_data_to_record( + memory_data=memory_data, + strategy_config=strategy_config, + thread_context=thread_context, + extraction_time=extraction_time, ) - extracted_memories.append(memory) - - return extracted_memories - + for memory_data in valid_memories_data + ] except Exception as e: logger.error(f"Error extracting memories from session thread {session_id}: {e}") return [] @@ -1026,6 +1267,7 @@ async def index_long_term_memories( for memory in valid_memories: current_memory = memory was_deduplicated = False + is_summary_memory = current_memory.extraction_strategy == "summary" # Check for id-based duplicates if not was_deduplicated: @@ -1041,7 +1283,7 @@ async def index_long_term_memories( current_memory = deduped_memory or current_memory # Check for hash-based duplicates - if not was_deduplicated: + if not was_deduplicated and not is_summary_memory: deduped_memory, was_dup = await deduplicate_by_hash( memory=current_memory, redis_client=redis, @@ -1053,7 +1295,11 @@ async def index_long_term_memories( current_memory = deduped_memory or current_memory # Check for semantic duplicates (respects compact_semantic_duplicates setting) - if not was_deduplicated and settings.compact_semantic_duplicates: + if ( + not was_deduplicated + and not is_summary_memory + and settings.compact_semantic_duplicates + ): deduped_memory, was_merged = await deduplicate_by_semantic_search( memory=current_memory, redis_client=redis, @@ -1118,6 +1364,7 @@ async def search_long_term_memories( entities: Entities | None = None, distance_threshold: float | None = None, memory_type: MemoryType | None = None, + extraction_strategy: ExtractionStrategy | None = None, event_date: EventDate | None = None, memory_hash: MemoryHash | None = None, server_side_recency: bool | None = None, @@ -1148,6 +1395,7 @@ async def search_long_term_memories( entities: Optional entities filter distance_threshold: Optional similarity threshold memory_type: Optional memory type filter + extraction_strategy: Optional extraction strategy filter event_date: Optional event date filter memory_hash: Optional memory hash filter limit: Maximum number of results @@ -1175,6 +1423,7 @@ async def search_long_term_memories( topics=topics, entities=entities, memory_type=memory_type, + extraction_strategy=extraction_strategy, event_date=event_date, memory_hash=memory_hash, limit=limit, @@ -1216,6 +1465,7 @@ async def search_long_term_memories( topics=topics, entities=entities, memory_type=memory_type, + extraction_strategy=extraction_strategy, event_date=event_date, memory_hash=memory_hash, distance_threshold=distance_threshold, @@ -1247,6 +1497,7 @@ async def search_long_term_memories( topics=topics, entities=entities, memory_type=memory_type, + extraction_strategy=extraction_strategy, event_date=event_date, memory_hash=memory_hash, distance_threshold=distance_threshold, @@ -1837,6 +2088,7 @@ async def promote_working_memory_to_long_term( # Set extraction strategy configuration from working memory current_memory.extraction_strategy = "message" + current_memory.extraction_strategy_config = {} # Collect memory record for batch indexing message_records_to_index.append(current_memory) diff --git a/agent_memory_server/mcp.py b/agent_memory_server/mcp.py index 5bf55e04..10e7c084 100644 --- a/agent_memory_server/mcp.py +++ b/agent_memory_server/mcp.py @@ -21,6 +21,7 @@ CreatedAt, Entities, EventDate, + ExtractionStrategy, LastAccessed, MemoryType, Namespace, @@ -468,6 +469,7 @@ async def search_long_term_memory( last_accessed: LastAccessed | None = None, user_id: UserId | None = None, memory_type: MemoryType | None = None, + extraction_strategy: ExtractionStrategy | None = None, event_date: EventDate | None = None, distance_threshold: float | None = None, limit: int = 10, @@ -578,6 +580,7 @@ async def search_long_term_memory( last_accessed: Filter by last access date user_id: Filter by user ID memory_type: Filter by memory type + extraction_strategy: Filter by extraction strategy event_date: Filter by event date (for episodic memories) distance_threshold: Distance threshold for semantic search limit: Maximum number of results @@ -617,6 +620,7 @@ async def search_long_term_memory( last_accessed=last_accessed, user_id=user_id, memory_type=memory_type, + extraction_strategy=extraction_strategy, event_date=event_date, distance_threshold=distance_threshold, limit=limit, diff --git a/agent_memory_server/memory_vector_db.py b/agent_memory_server/memory_vector_db.py index f6c5cfdf..1cd55dbb 100644 --- a/agent_memory_server/memory_vector_db.py +++ b/agent_memory_server/memory_vector_db.py @@ -3,6 +3,7 @@ with a RedisVL-based implementation for Redis backends. """ +import json import logging import re from abc import ABC, abstractmethod @@ -27,6 +28,7 @@ DiscreteMemoryExtracted, Entities, EventDate, + ExtractionStrategy, Id, LastAccessed, MemoryHash, @@ -167,6 +169,7 @@ async def search_memories( topics: Topics | None = None, entities: Entities | None = None, memory_type: MemoryType | None = None, + extraction_strategy: ExtractionStrategy | None = None, event_date: EventDate | None = None, memory_hash: MemoryHash | None = None, id: Id | None = None, @@ -198,6 +201,7 @@ async def search_memories( topics: Optional topics filter entities: Optional entities filter memory_type: Optional memory type filter + extraction_strategy: Optional extraction strategy filter event_date: Optional event date filter memory_hash: Optional memory hash filter id: Optional memory ID filter @@ -267,6 +271,7 @@ async def list_memories( topics: Topics | None = None, entities: Entities | None = None, memory_type: MemoryType | None = None, + extraction_strategy: ExtractionStrategy | None = None, event_date: EventDate | None = None, memory_hash: MemoryHash | None = None, id: Id | None = None, @@ -289,6 +294,7 @@ async def list_memories( topics: Optional topics filter entities: Optional entities filter memory_type: Optional memory type filter + extraction_strategy: Optional extraction strategy filter event_date: Optional event date filter memory_hash: Optional memory hash filter id: Optional memory ID filter @@ -393,6 +399,9 @@ class RedisVLMemoryVectorDatabase(MemoryVectorDatabase): "memory_hash", "discrete_memory_extracted", "memory_type", + "extraction_strategy", + "extraction_strategy_config", + "metadata", "persisted_at", "extracted_from", "event_date", @@ -457,6 +466,10 @@ def _memory_to_data(self, memory: MemoryRecord) -> dict[str, Any]: if hasattr(memory.memory_type, "value") else str(memory.memory_type) ) + extraction_strategy_config = json.dumps( + memory.extraction_strategy_config or {}, sort_keys=True + ) + metadata = json.dumps(memory.metadata or {}, sort_keys=True) data: dict[str, Any] = { "text": memory.text, @@ -465,6 +478,9 @@ def _memory_to_data(self, memory: MemoryRecord) -> dict[str, Any]: "user_id": memory.user_id or "", "namespace": memory.namespace or "", "memory_type": memory_type_val, + "extraction_strategy": memory.extraction_strategy or "", + "extraction_strategy_config": extraction_strategy_config, + "metadata": metadata, "topics": topics_str, "entities": entities_str, "memory_hash": memory.memory_hash or "", @@ -533,6 +549,21 @@ def parse_timestamp(val: Any) -> datetime | None: persisted_at = parse_timestamp(fields.get("persisted_at")) event_date = parse_timestamp(fields.get("event_date")) + def parse_json_dict(val: Any) -> dict[str, Any]: + if val is None or val == "": + return {} + if isinstance(val, dict): + return val + if isinstance(val, bytes): + val = val.decode("utf-8") + if isinstance(val, str): + try: + parsed = json.loads(val) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + # Provide defaults for required fields if not created_at: created_at = datetime.now(UTC) @@ -559,6 +590,12 @@ def parse_timestamp(val: Any) -> datetime | None: user_id = fields.get("user_id") or None namespace = fields.get("namespace") or None + extraction_strategy = fields.get("extraction_strategy") + if extraction_strategy is None: + extraction_strategy = ( + "message" if fields.get("memory_type") == "message" else "discrete" + ) + return MemoryRecordResult( text=fields.get("text", ""), id=fields.get("id_", ""), @@ -575,6 +612,11 @@ def parse_timestamp(val: Any) -> datetime | None: memory_hash=fields.get("memory_hash", ""), discrete_memory_extracted=fields.get("discrete_memory_extracted", "f"), memory_type=fields.get("memory_type", "message"), + extraction_strategy=extraction_strategy, + extraction_strategy_config=parse_json_dict( + fields.get("extraction_strategy_config") + ), + metadata=parse_json_dict(fields.get("metadata")), persisted_at=persisted_at, extracted_from=self._parse_list_field(fields.get("extracted_from")), event_date=event_date, @@ -791,6 +833,7 @@ async def search_memories( topics: Topics | None = None, entities: Entities | None = None, memory_type: MemoryType | None = None, + extraction_strategy: ExtractionStrategy | None = None, event_date: EventDate | None = None, memory_hash: MemoryHash | None = None, id: Id | None = None, @@ -817,6 +860,7 @@ async def search_memories( user_id=user_id, namespace=namespace, memory_type=memory_type, + extraction_strategy=extraction_strategy, topics=topics, entities=entities, created_at=created_at, @@ -1028,6 +1072,7 @@ async def list_memories( topics: Topics | None = None, entities: Entities | None = None, memory_type: MemoryType | None = None, + extraction_strategy: ExtractionStrategy | None = None, event_date: EventDate | None = None, memory_hash: MemoryHash | None = None, id: Id | None = None, @@ -1049,6 +1094,7 @@ async def list_memories( user_id=user_id, namespace=namespace, memory_type=memory_type, + extraction_strategy=extraction_strategy, topics=topics, entities=entities, created_at=created_at, diff --git a/agent_memory_server/memory_vector_db_factory.py b/agent_memory_server/memory_vector_db_factory.py index 8541730e..9b9d99fa 100644 --- a/agent_memory_server/memory_vector_db_factory.py +++ b/agent_memory_server/memory_vector_db_factory.py @@ -150,6 +150,9 @@ def _build_redis_schema() -> dict: {"name": "entities", "type": "tag", "attrs": {"separator": ","}}, {"name": "memory_hash", "type": "tag"}, {"name": "discrete_memory_extracted", "type": "tag"}, + {"name": "extraction_strategy", "type": "tag"}, + {"name": "extraction_strategy_config", "type": "text"}, + {"name": "metadata", "type": "text"}, {"name": "pinned", "type": "tag"}, {"name": "extracted_from", "type": "tag", "attrs": {"separator": ","}}, {"name": "id_", "type": "tag"}, diff --git a/agent_memory_server/models.py b/agent_memory_server/models.py index 197ae73d..e7c2a505 100644 --- a/agent_memory_server/models.py +++ b/agent_memory_server/models.py @@ -16,6 +16,7 @@ CreatedAt, Entities, EventDate, + ExtractionStrategy, LastAccessed, MemoryType, Namespace, @@ -28,7 +29,7 @@ logger = logging.getLogger(__name__) -JSONTypes = str | float | int | bool | list | dict +JSONTypes = str | float | int | bool | list | dict | None class MemoryTypeEnum(str, Enum): @@ -111,6 +112,7 @@ class MemoryStrategyConfig(BaseModel): def model_dump(self, **kwargs) -> dict[str, Any]: """Override to ensure JSON serialization works properly.""" + kwargs.pop("mode", None) return super().model_dump(mode="json", **kwargs) @@ -330,6 +332,10 @@ class MemoryRecord(BaseModel): default_factory=dict, description="Configuration for the extraction strategy used", ) + metadata: dict[str, JSONTypes] = Field( + default_factory=dict, + description="Additional non-indexed metadata for provenance and display", + ) @field_validator("topics", "entities", "extracted_from", mode="after") @classmethod @@ -422,7 +428,6 @@ def get_create_long_term_memory_tool_description(self) -> str: """ from agent_memory_server.memory_strategies import get_memory_strategy - # Get the configured strategy strategy = get_memory_strategy( self.long_term_memory_strategy.strategy, **self.long_term_memory_strategy.config, @@ -771,6 +776,10 @@ class SearchRequest(BaseModel): default=None, description="Optional memory type to filter by", ) + extraction_strategy: ExtractionStrategy | None = Field( + default=None, + description="Optional extraction strategy to filter by", + ) event_date: EventDate | None = Field( default=None, description="Optional event date to filter by (for episodic memories)", @@ -849,6 +858,9 @@ def get_filters(self): if self.memory_type is not None: filters["memory_type"] = self.memory_type + if self.extraction_strategy is not None: + filters["extraction_strategy"] = self.extraction_strategy + if self.event_date is not None: filters["event_date"] = self.event_date @@ -1072,11 +1084,22 @@ class SummaryViewPartitionResult(BaseModel): group: dict[str, str] = Field( description="Concrete values for the view's group_by fields", ) - summary: str = Field(description="Summarized text for this partition") + summary: str = Field( + default="", + description="Summarized text for this partition", + ) memory_count: int = Field( ge=0, description="Number of memories that contributed to this summary", ) + empty: bool = Field( + default=False, + description="Whether this partition had no matching memories", + ) + empty_reason: str | None = Field( + default=None, + description="Machine-readable reason for an empty partition", + ) computed_at: datetime = Field( default_factory=lambda: datetime.now(UTC), description="When this summary was computed", diff --git a/agent_memory_server/summary_views.py b/agent_memory_server/summary_views.py index 8f3136b0..6adf5f20 100644 --- a/agent_memory_server/summary_views.py +++ b/agent_memory_server/summary_views.py @@ -21,9 +21,12 @@ from agent_memory_server.config import settings from agent_memory_server.filters import ( CreatedAt, + EventDate, + ExtractionStrategy, MemoryType, Namespace, SessionId, + Topics, UserId, ) from agent_memory_server.models import ( @@ -210,6 +213,24 @@ def _build_long_term_filters_for_view( filters: dict[str, Any] = {} + def _tag_filter(filter_cls: type, value: Any) -> Any: + if isinstance(value, filter_cls): + return value + if isinstance(value, str): + return filter_cls(eq=value) + if isinstance(value, list): + return filter_cls(any=[str(item) for item in value]) + if isinstance(value, dict): + return filter_cls(**value) + return filter_cls(eq=str(value)) + + def _datetime_filter(filter_cls: type, value: Any) -> Any: + if isinstance(value, filter_cls): + return value + if isinstance(value, dict): + return filter_cls(**value) + return filter_cls(eq=value) + def _apply_filter(key: str, value: str | Any) -> None: """Apply a single filter mapping from a raw key/value pair. @@ -225,6 +246,12 @@ def _apply_filter(key: str, value: str | Any) -> None: filters["session_id"] = SessionId(eq=str(value)) elif key == "memory_type": filters["memory_type"] = MemoryType(eq=str(value)) + elif key == "extraction_strategy": + filters["extraction_strategy"] = _tag_filter(ExtractionStrategy, value) + elif key == "topics": + filters["topics"] = _tag_filter(Topics, value) + elif key == "event_date": + filters["event_date"] = _datetime_filter(EventDate, value) # Static filters from the view config for key, value in view.filters.items(): @@ -375,10 +402,33 @@ def _build_long_term_summary_prompt( # single extremely long memory cannot dominate the prompt. max_bullet_tokens = min(1024, full_context_tokens // 20) + sorted_memories = sorted( + memories, + key=lambda mem: ( + mem.event_date or mem.created_at, + mem.created_at, + mem.session_id or "", + mem.extraction_strategy or "", + mem.id, + ), + ) + bullet_lines: list[str] = [] - for mem in memories[:_MAX_MEMORIES_FOR_LLM_PROMPT]: + for mem in sorted_memories[:_MAX_MEMORIES_FOR_LLM_PROMPT]: text = mem.text or "" - bullet = f"- {text}" + memory_payload = { + "id": mem.id, + "text": text, + "created_at": mem.created_at.isoformat() if mem.created_at else None, + "event_date": mem.event_date.isoformat() if mem.event_date else None, + "extraction_strategy": mem.extraction_strategy, + "session_id": mem.session_id, + "namespace": mem.namespace, + "user_id": mem.user_id, + "topics": mem.topics or [], + "metadata": mem.metadata or {}, + } + bullet = json.dumps(memory_payload, sort_keys=True, separators=(",", ":")) bullet_tokens = len(encoding.encode(bullet)) if bullet_tokens > max_bullet_tokens: @@ -386,8 +436,8 @@ def _build_long_term_summary_prompt( # recompute tokens. This mirrors the approach used in # agent_memory_server.summarization. approx_chars = max_bullet_tokens * 4 - text = text[:approx_chars] - bullet = f"- {text}" + memory_payload["text"] = text[:approx_chars] + bullet = json.dumps(memory_payload, sort_keys=True, separators=(",", ":")) bullet_tokens = len(encoding.encode(bullet)) if bullet_tokens > remaining_tokens: @@ -420,15 +470,27 @@ async def summarize_partition_long_term( """ if not memories: - summary_text = f"No memories found for group {group!r}." return SummaryViewPartitionResult( view_id=view.id, group=group, - summary=summary_text, + summary="", memory_count=0, + empty=True, + empty_reason="no_matching_memories", computed_at=datetime.now(UTC), ) + memories = sorted( + memories, + key=lambda mem: ( + mem.event_date or mem.created_at, + mem.created_at, + mem.session_id or "", + mem.extraction_strategy or "", + mem.id, + ), + ) + # If no LLM credentials are configured, fall back to a simple # deterministic summary that just concatenates memory texts. if not ( @@ -452,7 +514,9 @@ async def summarize_partition_long_term( default_instructions = ( "You are a summarization assistant. Given a set of long-term " "memories, produce a concise summary that highlights key facts, " - "stable preferences, and important events relevant to the group." + "stable preferences, and important events relevant to the group. " + "Use only supplied memory metadata for dates and timestamps; " + "ignore missing metadata rather than inferring it." ) instructions = view.prompt or default_instructions diff --git a/agent_memory_server/utils/redis_query.py b/agent_memory_server/utils/redis_query.py index 4ddfac3d..ea0c856f 100644 --- a/agent_memory_server/utils/redis_query.py +++ b/agent_memory_server/utils/redis_query.py @@ -32,6 +32,9 @@ class RecencyAggregationQuery(AggregationQuery): "memory_hash", "discrete_memory_extracted", "memory_type", + "extraction_strategy", + "extraction_strategy_config", + "metadata", "persisted_at", "extracted_from", "event_date", diff --git a/docs/api.md b/docs/api.md index 766e839c..32181f16 100644 --- a/docs/api.md +++ b/docs/api.md @@ -589,6 +589,7 @@ A memory record | `event_date` | string \| null | No | Date/time when the event described in this memory occurred ( | | `extraction_strategy` | string | No | Memory extraction strategy used when this was promoted from | | `extraction_strategy_config` | object | No | Configuration for the extraction strategy used | +| `metadata` | object | No | Strategy-specific metadata and provenance | ### CreateMemoryRecordRequest @@ -643,6 +644,7 @@ Payload for long-term memory search | `distance_threshold` | number \| null | No | Optional distance threshold to filter by | | `memory_type` | MemoryType \| null | No | Optional memory type to filter by | | `event_date` | EventDate \| null | No | Optional event date to filter by (for episodic memories) | +| `extraction_strategy` | ExtractionStrategy \| null | No | Optional extraction strategy to filter by | | `limit` | integer | No | Optional limit on the number of results | | `offset` | integer | No | Optional offset | | `recency_boost` | boolean \| null | No | Enable recency-aware re-ranking (defaults to enabled if None | @@ -745,6 +747,8 @@ group_by fields, e.g. {"user_id": "alice"} or | `summary` | string | Yes | Summarized text for this partition | | `memory_count` | integer | Yes | Number of memories that contributed to this summary | | `computed_at` | string | No | When this summary was computed | +| `empty` | boolean | No | True when the partition had no matching memories | +| `empty_reason` | string \| null | No | Machine-readable reason for an empty result, such as `no_matching_memories` | ### Task diff --git a/docs/memory-extraction-strategies.md b/docs/memory-extraction-strategies.md index bf27f320..74db82d7 100644 --- a/docs/memory-extraction-strategies.md +++ b/docs/memory-extraction-strategies.md @@ -7,7 +7,7 @@ This reference documents the configurable extraction strategies that determine h | Strategy | Description | Best For | |----------|-------------|----------| | **Discrete** (default) | Extract individual facts and preferences | General chat, factual information | -| **Summary** | Create conversation summaries | Meeting notes, long conversations | +| **Summary** | Create one durable session/thread summary | Coding-agent sessions, meeting notes, long conversations | | **Preferences** | Focus on user preferences and characteristics | Personalization, user profiles | | **Custom** | Use domain-specific extraction prompts | Technical, legal, medical domains | @@ -69,9 +69,17 @@ working_memory = WorkingMemory( **Configuration Options:** - `max_summary_length`: Maximum characters in summary (default: 500) +- `summary_version`: Optional version marker for deterministic reruns +- `topics`: Optional additional topics to attach to summary memories **Best for:** Long conversations, meeting notes, comprehensive context preservation. +Summary extraction produces a durable semantic memory with +`extraction_strategy="summary"` and topic `thread-summary`. The summary memory +uses a deterministic ID per namespace/user/session so reruns update the same +record. If the source message fingerprint and `summary_version` are unchanged, +rerunning summary extraction is a no-op. + **Example Output:** ```json { @@ -232,6 +240,18 @@ curl -X PUT "http://localhost:8000/v1/working-memory/my-session" \ For more comprehensive integration examples, see [Memory Integration Patterns](memory-integration-patterns.md). +## Deployment Note + +`extraction_strategy` is indexed as a RediSearch tag field. After deploying a +version that adds this field, rebuild the long-term memory index so existing +records are available through extraction-strategy filters: + +```bash +uv run agent-memory rebuild-index +``` + +Run `uv run agent-memory migrate-memories` as usual for data migrations. + ## Best Practices ### 1. Strategy Selection Guidelines diff --git a/docs/summary-views.md b/docs/summary-views.md index 632b21e3..af2ff278 100644 --- a/docs/summary-views.md +++ b/docs/summary-views.md @@ -62,6 +62,25 @@ Create summaries for each conversation session: } ``` +To summarize durable thread summaries created by the `summary` extraction +strategy, filter long-term memory by extraction metadata: + +```json +{ + "name": "coding_agent_thread_summaries", + "source": "long_term", + "group_by": ["namespace", "session_id"], + "filters": { + "memory_type": "semantic", + "extraction_strategy": "summary", + "topics": {"all": ["thread-summary"]} + } +} +``` + +Summary views can filter by `extraction_strategy`, `topics`, and `event_date` +in addition to `user_id`, `namespace`, `session_id`, and `memory_type`. + ## API Endpoints ### Create a Summary View @@ -135,10 +154,16 @@ Response: "group": {"user_id": "alice"}, "summary": "Alice prefers dark mode and uses Python for ML projects...", "memory_count": 42, + "empty": false, + "empty_reason": null, "computed_at": "2024-01-15T10:30:00Z" } ``` +When no memories match a partition, the response is structured instead of a +placeholder summary: `summary` is an empty string, `memory_count` is `0`, +`empty` is `true`, and `empty_reason` is `"no_matching_memories"`. + ### Run All Partitions (Async) Trigger a full background recompute of all partitions: @@ -193,6 +218,9 @@ Both `group_by` and `filters` support: - `namespace` - Partition/filter by namespace - `session_id` - Partition/filter by session - `memory_type` - Partition/filter by type (`semantic`, `episodic`, `message`) +- `extraction_strategy` - Filter by how memories were extracted (`summary`, `discrete`, `preferences`, `custom`) +- `topics` - Filter by tag list, including `{"all": ["thread-summary"]}` +- `event_date` - Filter by structured event date ranges ## Continuous Mode diff --git a/tests/conftest.py b/tests/conftest.py index a33096a3..c16cee3e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -673,6 +673,7 @@ async def search_memories( topics: Any = None, entities: Any = None, memory_type: Any = None, + extraction_strategy: Any = None, event_date: Any = None, memory_hash: Any = None, id: Any = None, @@ -723,6 +724,13 @@ async def search_memories( ) if mem_type_val != memory_type.eq: continue + if ( + extraction_strategy + and hasattr(extraction_strategy, "eq") + and extraction_strategy.eq + and memory.extraction_strategy != extraction_strategy.eq + ): + continue result = MemoryRecordResult( id=memory.id, @@ -741,6 +749,9 @@ async def search_memories( if hasattr(memory.memory_type, "value") else str(memory.memory_type), persisted_at=memory.persisted_at, + extraction_strategy=memory.extraction_strategy, + extraction_strategy_config=memory.extraction_strategy_config, + metadata=memory.metadata, ) results.append(result) @@ -799,6 +810,7 @@ async def list_memories( topics: Any = None, entities: Any = None, memory_type: Any = None, + extraction_strategy: Any = None, event_date: Any = None, memory_hash: Any = None, id: Any = None, @@ -848,6 +860,13 @@ async def list_memories( ) if mem_type_val != memory_type.eq: continue + if ( + extraction_strategy + and hasattr(extraction_strategy, "eq") + and extraction_strategy.eq + and memory.extraction_strategy != extraction_strategy.eq + ): + continue result = MemoryRecordResult( id=memory.id, @@ -866,6 +885,9 @@ async def list_memories( if hasattr(memory.memory_type, "value") else str(memory.memory_type), persisted_at=memory.persisted_at, + extraction_strategy=memory.extraction_strategy, + extraction_strategy_config=memory.extraction_strategy_config, + metadata=memory.metadata, ) results.append(result) diff --git a/tests/test_extraction_logic_fix.py b/tests/test_extraction_logic_fix.py index 9047ee09..01f810f4 100644 --- a/tests/test_extraction_logic_fix.py +++ b/tests/test_extraction_logic_fix.py @@ -12,11 +12,17 @@ import pytest from agent_memory_server.long_term_memory import ( + _coalesce_summary_memory_data, extract_memories_from_session_thread, promote_working_memory_to_long_term, run_delayed_extraction, ) -from agent_memory_server.models import MemoryMessage, MemoryRecord, WorkingMemory +from agent_memory_server.models import ( + MemoryMessage, + MemoryRecord, + MemoryStrategyConfig, + WorkingMemory, +) from agent_memory_server.working_memory import get_working_memory, set_working_memory @@ -237,6 +243,7 @@ async def test_thread_aware_extraction_maps_event_date(self, async_redis_client) ), ], memories=[], + long_term_memory_strategy=MemoryStrategyConfig(strategy="discrete"), ) await set_working_memory(working_memory, redis_client=async_redis_client) @@ -290,6 +297,7 @@ async def test_thread_aware_extraction_skips_non_string_event_date( ), ], memories=[], + long_term_memory_strategy=MemoryStrategyConfig(strategy="discrete"), ) await set_working_memory(working_memory, redis_client=async_redis_client) @@ -448,3 +456,213 @@ async def test_no_extraction_when_debounced(self, async_redis_client): finally: settings.enable_discrete_memory_extraction = original_setting + + @pytest.mark.asyncio + async def test_summary_strategy_produces_first_class_thread_summary( + self, async_redis_client + ): + """Summary extraction should produce one durable semantic thread memory.""" + session_id = "test-summary-thread" + user_id = "test-user" + namespace = "test" + + working_memory = WorkingMemory( + session_id=session_id, + user_id=user_id, + namespace=namespace, + messages=[ + MemoryMessage( + id="msg-1", + role="user", + content="We should use Redis for search.", + created_at=datetime(2026, 5, 1, 12, 0, tzinfo=UTC), + discrete_memory_extracted="f", + ), + MemoryMessage( + id="msg-2", + role="assistant", + content="Redis is a good fit for that.", + created_at=datetime(2026, 5, 1, 12, 1, tzinfo=UTC), + discrete_memory_extracted="f", + ), + ], + memories=[], + long_term_memory_strategy=MemoryStrategyConfig( + strategy="summary", + config={"topics": ["coding-agent"], "summary_version": "v2"}, + ), + ) + await set_working_memory(working_memory, redis_client=async_redis_client) + + mock_strategy = AsyncMock() + mock_strategy.extract_memories.return_value = [ + { + "type": "semantic", + "text": "User and assistant discussed using Redis for search.", + "topics": ["redis"], + "entities": ["User", "Redis"], + } + ] + + with ( + patch( + "agent_memory_server.long_term_memory._existing_summary_matches_source", + return_value=False, + ), + patch( + "agent_memory_server.memory_strategies.get_memory_strategy", + return_value=mock_strategy, + ), + ): + extracted = await extract_memories_from_session_thread( + session_id=session_id, + namespace=namespace, + user_id=user_id, + ) + + assert len(extracted) == 1 + summary = extracted[0] + assert summary.id.startswith("thread_summary_") + assert summary.memory_type == "semantic" + assert summary.extraction_strategy == "summary" + assert summary.event_date is None + assert summary.session_id == session_id + assert summary.namespace == namespace + assert summary.user_id == user_id + assert summary.extracted_from == ["msg-1", "msg-2"] + assert summary.topics == ["coding-agent", "redis", "thread-summary"] + assert summary.metadata["source_session_id"] == session_id + assert summary.metadata["message_count"] == 2 + assert summary.metadata["summary_version"] == "v2" + assert summary.metadata["source_message_ids"] == ["msg-1", "msg-2"] + assert summary.metadata["source_created_at_min"] == ( + "2026-05-01T12:00:00+00:00" + ) + assert summary.metadata["source_created_at_max"] == ( + "2026-05-01T12:01:00+00:00" + ) + assert "source_message_fingerprint" in summary.metadata + + mock_strategy.extract_memories.assert_awaited_once() + call_text = mock_strategy.extract_memories.await_args.args[0] + call_context = mock_strategy.extract_memories.await_args.kwargs["context"] + assert "[USER]: We should use Redis for search." in call_text + assert call_context["source_message_ids"] == ["msg-1", "msg-2"] + + def test_summary_strategy_coalesces_multiple_model_memories(self): + """Summary extraction should produce one record even if the model returns many.""" + coalesced = _coalesce_summary_memory_data( + [ + { + "type": "semantic", + "text": "User chose Redis for search.", + "topics": ["redis"], + "entities": ["Redis"], + }, + { + "type": "semantic", + "text": "User plans to add dashboard filters.", + "topics": ["dashboard"], + "entities": ["User"], + }, + ] + ) + + assert coalesced == [ + { + "type": "semantic", + "text": ( + "User chose Redis for search.\n\n" + "User plans to add dashboard filters." + ), + "topics": ["dashboard", "redis"], + "entities": ["Redis", "User"], + } + ] + + @pytest.mark.asyncio + async def test_summary_strategy_skips_unchanged_existing_summary( + self, async_redis_client + ): + """Unchanged summary extraction should skip the model call.""" + session_id = "test-summary-thread-skip" + working_memory = WorkingMemory( + session_id=session_id, + messages=[ + MemoryMessage( + id="msg-1", + role="user", + content="No changes.", + created_at=datetime(2026, 5, 1, 12, 0, tzinfo=UTC), + discrete_memory_extracted="f", + ) + ], + memories=[], + long_term_memory_strategy=MemoryStrategyConfig(strategy="summary"), + ) + await set_working_memory(working_memory, redis_client=async_redis_client) + + with ( + patch( + "agent_memory_server.long_term_memory._existing_summary_matches_source", + return_value=True, + ) as mock_existing, + patch( + "agent_memory_server.memory_strategies.get_memory_strategy" + ) as mock_get_strategy, + ): + extracted = await extract_memories_from_session_thread( + session_id=session_id + ) + + assert extracted == [] + mock_existing.assert_awaited_once() + mock_get_strategy.assert_not_called() + + @pytest.mark.asyncio + async def test_summary_strategy_changed_thread_keeps_same_id( + self, async_redis_client + ): + """Changed sessions should refresh the same deterministic summary ID.""" + session_id = "test-summary-thread-same-id" + + async def run_once(message_text: str) -> str: + working_memory = WorkingMemory( + session_id=session_id, + messages=[ + MemoryMessage( + id="msg-1", + role="user", + content=message_text, + created_at=datetime(2026, 5, 1, 12, 0, tzinfo=UTC), + discrete_memory_extracted="f", + ) + ], + memories=[], + long_term_memory_strategy=MemoryStrategyConfig(strategy="summary"), + ) + await set_working_memory(working_memory, redis_client=async_redis_client) + + mock_strategy = AsyncMock() + mock_strategy.extract_memories.return_value = [ + {"type": "semantic", "text": f"Summary: {message_text}"} + ] + with ( + patch( + "agent_memory_server.long_term_memory._existing_summary_matches_source", + return_value=False, + ), + patch( + "agent_memory_server.memory_strategies.get_memory_strategy", + return_value=mock_strategy, + ), + ): + extracted = await extract_memories_from_session_thread( + session_id=session_id + ) + return extracted[0].id + + first_id = await run_once("Initial content.") + second_id = await run_once("Changed content.") + + assert first_id == second_id diff --git a/tests/test_long_term_memory.py b/tests/test_long_term_memory.py index 946049df..bac3199f 100644 --- a/tests/test_long_term_memory.py +++ b/tests/test_long_term_memory.py @@ -736,6 +736,69 @@ async def test_index_skips_semantic_dedup_when_disabled( # Semantic dedup should NOT be called when setting is disabled mock_semantic_dedup.assert_not_called() + @pytest.mark.asyncio + async def test_index_skips_hash_and_semantic_dedup_for_summary_memories( + self, mock_async_redis_client + ): + """Summary records must keep their deterministic IDs and provenance.""" + with ( + patch("agent_memory_server.long_term_memory.settings") as mock_settings, + patch( + "agent_memory_server.long_term_memory.get_memory_vector_db" + ) as mock_get_db, + patch( + "agent_memory_server.long_term_memory.deduplicate_by_id" + ) as mock_id_dedup, + patch( + "agent_memory_server.long_term_memory.deduplicate_by_hash" + ) as mock_hash_dedup, + patch( + "agent_memory_server.long_term_memory.deduplicate_by_semantic_search" + ) as mock_semantic_dedup, + patch( + "agent_memory_server.long_term_memory.get_background_tasks" + ) as mock_get_bg, + ): + mock_settings.compact_semantic_duplicates = True + mock_settings.generation_model = "test-model" + mock_settings.long_term_memory_index_name = "memory_records" + mock_settings.llm_task_timeout_minutes = 5 + mock_settings.enable_discrete_memory_extraction = False + + mock_adapter = AsyncMock() + mock_get_db.return_value = mock_adapter + mock_id_dedup.return_value = (None, False) + mock_bg = MagicMock() + mock_get_bg.return_value = mock_bg + + summary = MemoryRecord( + id="thread_summary_abc123", + text="User discussed Redis search.", + namespace="test", + memory_type=MemoryTypeEnum.SEMANTIC, + extraction_strategy="summary", + metadata={ + "source_message_fingerprint": "fp", + "summary_version": "thread-summary-v1", + }, + ) + + await index_long_term_memories( + [summary], + redis_client=mock_async_redis_client, + deduplicate=True, + ) + + mock_id_dedup.assert_awaited_once() + mock_hash_dedup.assert_not_called() + mock_semantic_dedup.assert_not_called() + mock_adapter.add_memories.assert_awaited_once() + indexed_memories = mock_adapter.add_memories.await_args.args[0] + assert indexed_memories == [summary] + assert indexed_memories[0].id == "thread_summary_abc123" + assert indexed_memories[0].extraction_strategy == "summary" + assert indexed_memories[0].metadata["source_message_fingerprint"] == "fp" + @pytest.mark.asyncio async def test_promote_working_memory_to_long_term(self, mock_async_redis_client): """Test promoting memories from working memory to long-term storage""" diff --git a/tests/test_memory_vector_db.py b/tests/test_memory_vector_db.py index 51b95ee9..6733e105 100644 --- a/tests/test_memory_vector_db.py +++ b/tests/test_memory_vector_db.py @@ -4,6 +4,7 @@ import pytest +from agent_memory_server.filters import ExtractionStrategy from agent_memory_server.memory_vector_db import ( MemoryVectorDatabase, PhraseAwareAggregateHybridQuery, @@ -119,6 +120,9 @@ def test_memory_to_data_conversion(self): topics=["testing", "memory"], entities=["test"], memory_type=MemoryTypeEnum.SEMANTIC, + extraction_strategy="summary", + extraction_strategy_config={"summary_version": "v1"}, + metadata={"source_session_id": "session-456", "message_count": 2}, ) data = db._memory_to_data(memory) @@ -131,6 +135,9 @@ def test_memory_to_data_conversion(self): assert data["topics"] == "testing,memory" assert data["entities"] == "test" assert data["memory_type"] == "semantic" + assert data["extraction_strategy"] == "summary" + assert '"summary_version": "v1"' in data["extraction_strategy_config"] + assert '"message_count": 2' in data["metadata"] def test_data_to_memory_result_conversion(self): """Test converting data dict to MemoryRecordResult.""" @@ -147,6 +154,9 @@ def test_data_to_memory_result_conversion(self): "topics": "testing,memory", "entities": "test", "memory_type": "semantic", + "extraction_strategy": "summary", + "extraction_strategy_config": '{"summary_version":"v1"}', + "metadata": '{"source_session_id":"session-456","message_count":2}', "created_at": "1704067200", # 2024-01-01T00:00:00Z "last_accessed": "1704067200", "updated_at": "1704067200", @@ -163,9 +173,78 @@ def test_data_to_memory_result_conversion(self): assert result.topics == ["testing", "memory"] assert result.entities == ["test"] assert result.memory_type == "semantic" + assert result.extraction_strategy == "summary" + assert result.extraction_strategy_config == {"summary_version": "v1"} + assert result.metadata == { + "source_session_id": "session-456", + "message_count": 2, + } assert result.dist == 0.2 assert result.discrete_memory_extracted == "t" + def test_data_to_memory_result_preserves_empty_extraction_strategy(self): + """Explicit empty extraction_strategy should not trigger legacy fallback.""" + mock_index = MagicMock() + mock_embeddings = MockEmbeddings() + db = RedisVLMemoryVectorDatabase(mock_index, mock_embeddings) + + result = db._data_to_memory_result( + { + "id_": "test-123", + "text": "This is a test memory", + "memory_type": "semantic", + "extraction_strategy": "", + } + ) + + assert result.extraction_strategy == "" + + def test_data_to_memory_result_defaults_missing_extraction_strategy(self): + """Missing extraction_strategy should retain legacy read defaults.""" + mock_index = MagicMock() + mock_embeddings = MockEmbeddings() + db = RedisVLMemoryVectorDatabase(mock_index, mock_embeddings) + + semantic_result = db._data_to_memory_result( + { + "id_": "semantic-123", + "text": "This is a semantic memory", + "memory_type": "semantic", + } + ) + message_result = db._data_to_memory_result( + { + "id_": "message-123", + "text": "This is a message memory", + "memory_type": "message", + } + ) + + assert semantic_result.extraction_strategy == "discrete" + assert message_result.extraction_strategy == "message" + + def test_redis_schema_includes_extraction_metadata_fields(self): + """RedisVL schema should expose extraction strategy fields.""" + schema = _build_redis_schema() + fields = {field["name"]: field for field in schema["fields"]} + + assert fields["extraction_strategy"]["type"] == "tag" + assert fields["extraction_strategy_config"]["type"] == "text" + assert fields["metadata"]["type"] == "text" + + def test_build_filter_expression_supports_extraction_strategy(self): + """The RedisVL filter expression should include extraction_strategy.""" + mock_index = MagicMock() + mock_embeddings = MockEmbeddings() + db = RedisVLMemoryVectorDatabase(mock_index, mock_embeddings) + + expression = db._build_filter_expression( + extraction_strategy=ExtractionStrategy(eq="summary") + ) + + assert "extraction_strategy" in str(expression) + assert "summary" in str(expression) + @pytest.mark.asyncio async def test_add_memories_with_mock_index(self): """Test adding memories to a mock index.""" diff --git a/tests/test_summary_views.py b/tests/test_summary_views.py index 41acc0bb..a0621f50 100644 --- a/tests/test_summary_views.py +++ b/tests/test_summary_views.py @@ -3,6 +3,8 @@ These tests verify that all examples from docs/summary-views.md work correctly. """ +from datetime import UTC, datetime + import pytest from agent_memory_server.models import MemoryRecord, SummaryView, TaskStatusEnum @@ -628,6 +630,118 @@ async def test_combined_filters(self, client): assert view["filters"] == {"namespace": "test_ns", "memory_type": "semantic"} +def test_build_long_term_filters_supports_extraction_topics_and_event_date(): + """Summary views should support extraction strategy, topic, and event-date filters.""" + from agent_memory_server.summary_views import _build_long_term_filters_for_view + + view = SummaryView( + id="view-filters", + source="long_term", + group_by=["session_id"], + filters={ + "extraction_strategy": "summary", + "topics": {"all": ["thread-summary"]}, + "event_date": {"gte": datetime(2026, 5, 1, tzinfo=UTC)}, + }, + ) + + filters = _build_long_term_filters_for_view(view) + + assert filters["extraction_strategy"].eq == "summary" + assert filters["topics"].all == ["thread-summary"] + assert filters["event_date"].gte == datetime(2026, 5, 1, tzinfo=UTC) + + +def test_build_long_term_summary_prompt_includes_metadata_and_sorts(): + """Summary prompts should include memory metadata in deterministic order.""" + from agent_memory_server.summary_views import _build_long_term_summary_prompt + + view = SummaryView(id="view-prompt", source="long_term", group_by=["session_id"]) + newer = MemoryRecord( + id="mem-newer", + text="Newer memory", + session_id="s2", + namespace="test", + user_id="u1", + created_at=datetime(2026, 5, 2, 12, 0, tzinfo=UTC), + extraction_strategy="discrete", + topics=["preference"], + metadata={"source_session_id": "s2"}, + ) + older = MemoryRecord( + id="mem-older", + text="Older summary", + session_id="s1", + namespace="test", + user_id="u1", + created_at=datetime(2026, 5, 1, 12, 0, tzinfo=UTC), + extraction_strategy="summary", + topics=["thread-summary"], + metadata={"message_count": 3, "summary_version": "v1"}, + ) + + prompt = _build_long_term_summary_prompt( + view=view, + group={"session_id": "s1"}, + memories=[newer, older], + model_name="gpt-4o-mini", + instructions="Summarize these memories.", + ) + + assert prompt.index("mem-older") < prompt.index("mem-newer") + assert '"extraction_strategy":"summary"' in prompt + assert '"session_id":"s1"' in prompt + assert '"metadata":{"message_count":3,"summary_version":"v1"}' in prompt + assert "Current date" not in prompt + assert "current date" not in prompt.lower() + + +@pytest.mark.asyncio +async def test_summarize_partition_long_term_returns_structured_empty_result(): + """Empty summary partitions should be machine-readable, not text placeholders.""" + from agent_memory_server.summary_views import summarize_partition_long_term + + view = SummaryView(id="view-empty", source="long_term", group_by=["user_id"]) + + result = await summarize_partition_long_term(view, {"user_id": "missing"}, []) + + assert result.summary == "" + assert result.memory_count == 0 + assert result.empty is True + assert result.empty_reason == "no_matching_memories" + + +@pytest.mark.asyncio +async def test_summarize_partition_long_term_fallback_sorts_memories(monkeypatch): + """Fallback summaries should use the same deterministic memory ordering.""" + from agent_memory_server.config import settings + from agent_memory_server.summary_views import summarize_partition_long_term + + monkeypatch.setattr(settings, "openai_api_key", None) + monkeypatch.setattr(settings, "anthropic_api_key", None) + monkeypatch.setattr(settings, "aws_access_key_id", None) + + view = SummaryView(id="view-sort", source="long_term", group_by=["session_id"]) + newer = MemoryRecord( + id="mem-newer", + text="Newer memory", + created_at=datetime(2026, 5, 2, 12, 0, tzinfo=UTC), + extraction_strategy="discrete", + ) + older = MemoryRecord( + id="mem-older", + text="Older memory", + created_at=datetime(2026, 5, 1, 12, 0, tzinfo=UTC), + extraction_strategy="summary", + ) + + result = await summarize_partition_long_term( + view, {"session_id": "s1"}, [newer, older] + ) + + assert result.summary.index("Older memory") < result.summary.index("Newer memory") + + class TestConfigurationOptionsTable: """Tests verifying configuration options documented in the SummaryView Fields table.""" @@ -1082,7 +1196,7 @@ class TestSummarizePartitionLongTerm: @pytest.mark.asyncio async def test_returns_no_memories_message_for_empty_list(self): - """summarize_partition_long_term should return message when no memories.""" + """summarize_partition_long_term should return structured empty result.""" from agent_memory_server.models import SummaryView from agent_memory_server.summary_views import summarize_partition_long_term @@ -1095,7 +1209,9 @@ async def test_returns_no_memories_message_for_empty_list(self): ) assert result.memory_count == 0 - assert "No memories found" in result.summary + assert result.summary == "" + assert result.empty is True + assert result.empty_reason == "no_matching_memories" @pytest.mark.asyncio async def test_fallback_when_no_api_keys(self, monkeypatch):