diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/BackgroundJobRunner.cs b/dotnet/src/Microsoft.Agents.AI.Purview/BackgroundJobRunner.cs index 85a4fa54c3..03f73d8007 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/BackgroundJobRunner.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/BackgroundJobRunner.cs @@ -1,10 +1,14 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; +using Microsoft.Agents.AI.Purview.Models.Common; using Microsoft.Agents.AI.Purview.Models.Jobs; +using Microsoft.Agents.AI.Purview.Models.Requests; +using Microsoft.Agents.AI.Purview.Models.Responses; using Microsoft.Extensions.Logging; namespace Microsoft.Agents.AI.Purview; @@ -16,6 +20,7 @@ internal sealed class BackgroundJobRunner : IBackgroundJobRunner { private readonly IChannelHandler _channelHandler; private readonly IPurviewClient _purviewClient; + private readonly ICacheProvider _cacheProvider; private readonly ILogger _logger; /// @@ -23,12 +28,14 @@ internal sealed class BackgroundJobRunner : IBackgroundJobRunner /// /// The channel handler used to manage job channels. /// The Purview client used to send requests to Purview. + /// The cache provider used to store protection scopes results. /// The logger used to log information about background jobs. /// The settings used to configure Purview client behavior. - public BackgroundJobRunner(IChannelHandler channelHandler, IPurviewClient purviewClient, ILogger logger, PurviewSettings purviewSettings) + public BackgroundJobRunner(IChannelHandler channelHandler, IPurviewClient purviewClient, ICacheProvider cacheProvider, ILogger logger, PurviewSettings purviewSettings) { this._channelHandler = channelHandler; this._purviewClient = purviewClient; + this._cacheProvider = cacheProvider; this._logger = logger; for (int i = 0; i < purviewSettings.MaxConcurrentJobConsumers; i++) @@ -67,6 +74,28 @@ private async Task RunJobAsync(BackgroundJobBase job) break; case ContentActivityJob contentActivityJob: _ = await this._purviewClient.SendContentActivitiesAsync(contentActivityJob.Request, CancellationToken.None).ConfigureAwait(false); + break; + case ScopeRetrievalJob scopeRetrievalJob: + try + { + ProtectionScopesResponse response = await this._purviewClient.GetProtectionScopesAsync(scopeRetrievalJob.Request, CancellationToken.None).ConfigureAwait(false); + await this._cacheProvider.SetAsync(scopeRetrievalJob.CacheKey, response, CancellationToken.None).ConfigureAwait(false); + (bool shouldProcess, List _, ExecutionMode _) = ScopedContentProcessor.CheckApplicableScopes(scopeRetrievalJob.ProcessContentRequest, response); + if (!shouldProcess) + { + ProcessContentRequest pcRequest = scopeRetrievalJob.ProcessContentRequest; + ContentActivitiesRequest caRequest = new(pcRequest.UserId, pcRequest.TenantId, pcRequest.ContentToProcess, pcRequest.CorrelationId); + this._channelHandler.QueueJob(new ContentActivityJob(caRequest)); + } + } + catch (PurviewPaymentRequiredException ex) + { + await this._cacheProvider.SetAsync( + new PaymentRequiredCacheKey(scopeRetrievalJob.Request.TenantId), + new PaymentRequiredCacheEntry(ex.Message), + CancellationToken.None).ConfigureAwait(false); + } + break; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/PaymentRequiredCacheEntry.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/PaymentRequiredCacheEntry.cs new file mode 100644 index 0000000000..6bd9d40853 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/PaymentRequiredCacheEntry.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Agents.AI.Purview.Models.Common; + +/// +/// Cached tenant-level payment required state. +/// +internal sealed class PaymentRequiredCacheEntry +{ + /// + /// Creates a new instance of . + /// + /// The payment required error message. + public PaymentRequiredCacheEntry(string? message) + { + this.Message = message; + } + + /// + /// The payment required error message. + /// + public string? Message { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/PaymentRequiredCacheKey.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/PaymentRequiredCacheKey.cs new file mode 100644 index 0000000000..3c9ad4f813 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Common/PaymentRequiredCacheKey.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Agents.AI.Purview.Models.Common; + +/// +/// A cache key for tenant-level payment required state. +/// +internal sealed class PaymentRequiredCacheKey +{ + /// + /// Creates a new instance of . + /// + /// The id of the tenant. + public PaymentRequiredCacheKey(string tenantId) + { + this.TenantId = tenantId; + } + + /// + /// The id of the tenant. + /// + public string TenantId { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Jobs/ScopeRetrievalJob.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Jobs/ScopeRetrievalJob.cs new file mode 100644 index 0000000000..c23553f185 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Jobs/ScopeRetrievalJob.cs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Agents.AI.Purview.Models.Common; +using Microsoft.Agents.AI.Purview.Models.Requests; + +namespace Microsoft.Agents.AI.Purview.Models.Jobs; + +/// +/// Class representing a job that refreshes the protection scopes cache in the background. +/// +/// +/// Used by the parallel protection scopes retrieval path to warm the cache without blocking the +/// foreground ProcessContent call. +/// +internal sealed class ScopeRetrievalJob : BackgroundJobBase +{ + /// + /// Initializes a new instance of the class. + /// + /// The protection scopes request to send to Purview. + /// The cache key used to store the response. + /// The original process content request that triggered scope retrieval. + public ScopeRetrievalJob(ProtectionScopesRequest request, ProtectionScopesCacheKey cacheKey, ProcessContentRequest processContentRequest) + { + this.Request = request; + this.CacheKey = cacheKey; + this.ProcessContentRequest = processContentRequest; + } + + /// + /// Gets the protection scopes request. + /// + public ProtectionScopesRequest Request { get; } + + /// + /// Gets the cache key used to store the response. + /// + public ProtectionScopesCacheKey CacheKey { get; } + + /// + /// Gets the original process content request that triggered scope retrieval. + /// + public ProcessContentRequest ProcessContentRequest { get; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Requests/ProcessContentRequest.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Requests/ProcessContentRequest.cs index f8e9602cef..d41a3a2090 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/Models/Requests/ProcessContentRequest.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/Models/Requests/ProcessContentRequest.cs @@ -53,4 +53,10 @@ public ProcessContentRequest(ContentToProcess contentToProcess, string userId, s /// [JsonIgnore] internal string? ScopeIdentifier { get; set; } + + /// + /// Indicates whether the ProcessContent request should ask the service for inline evaluation. + /// + [JsonIgnore] + internal bool ProcessInline { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewClient.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewClient.cs index 28013f524e..43b564b58f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewClient.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewClient.cs @@ -130,6 +130,11 @@ public async Task ProcessContentAsync(ProcessContentRequ message.Headers.Add("If-None-Match", request.ScopeIdentifier); } + if (request.ProcessInline) + { + message.Headers.Add("Prefer", "evaluateInline"); + } + string content = JsonSerializer.Serialize(request, PurviewSerializationUtils.SerializationSettings.GetTypeInfo(typeof(ProcessContentRequest))); message.Content = new StringContent(content, Encoding.UTF8, "application/json"); diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/README.md b/dotnet/src/Microsoft.Agents.AI.Purview/README.md index 1a9fc70725..bcd1a26192 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/README.md +++ b/dotnet/src/Microsoft.Agents.AI.Purview/README.md @@ -218,8 +218,8 @@ The policy logic is identical; the only difference is the hook point in the pipe The user id from the prompt message(s) is reused for the response evaluation so both evaluations map consistently to the same user. -There are several optimizations to speed up Purview calls. Protection scope lookups (the first step in evaluation) are cached to minimize network calls. -If the policies allow content to be processed offline, the middleware will add the process content request to a channel and run it in a background worker. Similarly, the middleware will run a background request if no scopes apply and the interaction only has to be logged in Audit. +There are several optimizations to speed up Purview calls. Protection scope lookups (the first step in evaluation) are cached to minimize network calls. When a lookup is not cached, the middleware will refresh it in a background worker so the foreground ProcessContent request does not have to wait. +If the policies allow content to be processed offline, the middleware will add the process content request to a channel and run it in a background worker. Similarly, the middleware will run a background request if no scopes apply and the interaction only has to be logged in Audit. Payment Required responses from background scope lookups are cached at the tenant level so subsequent requests for the tenant short-circuit. ## Exceptions | Exception | Scenario | diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/ScopedContentProcessor.cs b/dotnet/src/Microsoft.Agents.AI.Purview/ScopedContentProcessor.cs index 3fb7aa6c4d..3e280014a0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/ScopedContentProcessor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/ScopedContentProcessor.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Purview.Models.Common; @@ -193,43 +194,60 @@ private async Task ProcessContentWithProtectionScopesAsy { ProtectionScopesRequest psRequest = CreateProtectionScopesRequest(pcRequest, pcRequest.UserId, pcRequest.TenantId, pcRequest.CorrelationId); + PaymentRequiredCacheEntry? cachedPaymentRequired = await this._cacheProvider.GetAsync( + new PaymentRequiredCacheKey(pcRequest.TenantId), + cancellationToken).ConfigureAwait(false); + + if (cachedPaymentRequired != null) + { + throw new PurviewPaymentRequiredException(cachedPaymentRequired.Message ?? "Payment required"); + } + ProtectionScopesCacheKey cacheKey = new(psRequest); ProtectionScopesResponse? cacheResponse = await this._cacheProvider.GetAsync(cacheKey, cancellationToken).ConfigureAwait(false); - ProtectionScopesResponse psResponse; - if (cacheResponse != null) { - psResponse = cacheResponse; + return await this.ProcessWithCachedScopesAsync(pcRequest, cacheResponse, cacheKey, cancellationToken).ConfigureAwait(false); } - else + + try { - psResponse = await this._purviewClient.GetProtectionScopesAsync(psRequest, cancellationToken).ConfigureAwait(false); - await this._cacheProvider.SetAsync(cacheKey, psResponse, cancellationToken).ConfigureAwait(false); + this._channelHandler.QueueJob(new ScopeRetrievalJob(psRequest, cacheKey, pcRequest)); + } + catch (PurviewJobException) + { + // QueueJob already logs failures. Scope warmup is best effort; don't block ProcessContent. } + return await this.CallProcessContentAsync(pcRequest, cacheKey, dlpActions: null, cancellationToken).ConfigureAwait(false); + } + + /// + /// Apply locally-cached protection scopes to the request and dispatch ProcessContent appropriately. + /// + private async Task ProcessWithCachedScopesAsync( + ProcessContentRequest pcRequest, + ProtectionScopesResponse psResponse, + ProtectionScopesCacheKey cacheKey, + CancellationToken cancellationToken) + { pcRequest.ScopeIdentifier = psResponse.ScopeIdentifier; (bool shouldProcess, List dlpActions, ExecutionMode executionMode) = CheckApplicableScopes(pcRequest, psResponse); if (shouldProcess) { + pcRequest.ProcessInline = executionMode == ExecutionMode.EvaluateInline; + if (executionMode == ExecutionMode.EvaluateOffline) { this._channelHandler.QueueJob(new ProcessContentJob(pcRequest)); return new ProcessContentResponse(); } - ProcessContentResponse pcResponse = await this._purviewClient.ProcessContentAsync(pcRequest, cancellationToken).ConfigureAwait(false); - - if (pcResponse.ProtectionScopeState == ProtectionScopeState.Modified) - { - await this._cacheProvider.RemoveAsync(cacheKey, cancellationToken).ConfigureAwait(false); - } - - pcResponse = CombinePolicyActions(pcResponse, dlpActions); - return pcResponse; + return await this.CallProcessContentAsync(pcRequest, cacheKey, dlpActions, cancellationToken).ConfigureAwait(false); } ContentActivitiesRequest caRequest = new(pcRequest.UserId, pcRequest.TenantId, pcRequest.ContentToProcess, pcRequest.CorrelationId); @@ -238,6 +256,30 @@ private async Task ProcessContentWithProtectionScopesAsy return new ProcessContentResponse(); } + /// + /// Call ProcessContent and invalidate the protection scopes cache when the response indicates the cached scopes are stale. + /// + private async Task CallProcessContentAsync( + ProcessContentRequest pcRequest, + ProtectionScopesCacheKey cacheKey, + List? dlpActions, + CancellationToken cancellationToken) + { + ProcessContentResponse pcResponse = await this._purviewClient.ProcessContentAsync(pcRequest, cancellationToken).ConfigureAwait(false); + + if (pcRequest.ScopeIdentifier != null && pcResponse.ProtectionScopeState == ProtectionScopeState.Modified) + { + await this._cacheProvider.RemoveAsync(cacheKey, cancellationToken).ConfigureAwait(false); + } + + if (dlpActions?.Count > 0) + { + pcResponse = CombinePolicyActions(pcResponse, dlpActions); + } + + return pcResponse; + } + /// /// Dedupe policy actions received from the service. /// @@ -248,9 +290,21 @@ private static ProcessContentResponse CombinePolicyActions(ProcessContentRespons { if (actionInfos?.Count > 0) { - pcResponse.PolicyActions = pcResponse.PolicyActions is null ? - actionInfos : - [.. pcResponse.PolicyActions, .. actionInfos]; + List combinedActions = []; + HashSet<(DlpAction Action, RestrictionAction? RestrictionAction)> seenActions = []; + IEnumerable allActions = pcResponse.PolicyActions is null + ? actionInfos + : pcResponse.PolicyActions.Concat(actionInfos); + + foreach (DlpActionInfo actionInfo in allActions) + { + if (seenActions.Add((actionInfo.Action, actionInfo.RestrictionAction))) + { + combinedActions.Add(actionInfo); + } + } + + pcResponse.PolicyActions = combinedActions; } return pcResponse; @@ -262,7 +316,7 @@ private static ProcessContentResponse CombinePolicyActions(ProcessContentRespons /// The process content request. /// The protection scopes response that was returned for the process content request. /// A bool indicating if the content needs to be processed. A list of applicable actions from the scopes response, and the execution mode for the process content request. - private static (bool shouldProcess, List dlpActions, ExecutionMode executionMode) CheckApplicableScopes(ProcessContentRequest pcRequest, ProtectionScopesResponse psResponse) + internal static (bool shouldProcess, List dlpActions, ExecutionMode executionMode) CheckApplicableScopes(ProcessContentRequest pcRequest, ProtectionScopesResponse psResponse) { ProtectionScopeActivities requestActivity = TranslateActivity(pcRequest.ContentToProcess.ActivityMetadata.Activity); @@ -284,7 +338,11 @@ private static (bool shouldProcess, List dlpActions, ExecutionMod foreach (var location in scope.Locations ?? Array.Empty()) { - locationMatch = location.DataType.EndsWith(locationType, StringComparison.OrdinalIgnoreCase) && location.Value.Equals(locationValue, StringComparison.OrdinalIgnoreCase); + if (location.DataType.EndsWith(locationType, StringComparison.OrdinalIgnoreCase) && location.Value.Equals(locationValue, StringComparison.OrdinalIgnoreCase)) + { + locationMatch = true; + break; + } } if (activityMatch && locationMatch) diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/Serialization/PurviewSerializationUtils.cs b/dotnet/src/Microsoft.Agents.AI.Purview/Serialization/PurviewSerializationUtils.cs index 320fbcd3b6..0be4c59267 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/Serialization/PurviewSerializationUtils.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/Serialization/PurviewSerializationUtils.cs @@ -18,6 +18,8 @@ namespace Microsoft.Agents.AI.Purview.Serialization; [JsonSerializable(typeof(ContentActivitiesRequest))] [JsonSerializable(typeof(ContentActivitiesResponse))] [JsonSerializable(typeof(ProtectionScopesCacheKey))] +[JsonSerializable(typeof(PaymentRequiredCacheKey))] +[JsonSerializable(typeof(PaymentRequiredCacheEntry))] internal sealed partial class SourceGenerationContext : JsonSerializerContext; /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs index 38abc903d3..6b857101c7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/PurviewClientTests.cs @@ -115,6 +115,24 @@ public async Task ProcessContentAsync_WithScopeIdentifier_IncludesIfNoneMatchHea Assert.Equal("\"test-scope-123\"", this._handler.IfNoneMatchHeader); } + [Fact] + public async Task ProcessContentAsync_WithProcessInline_IncludesPreferHeaderAsync() + { + // Arrange + var request = CreateValidProcessContentRequest(); + request.ProcessInline = true; + var expectedResponse = new ProcessContentResponse { Id = "test-id" }; + + this._handler.StatusCodeToReturn = HttpStatusCode.OK; + this._handler.ResponseToReturn = JsonSerializer.Serialize(expectedResponse, PurviewSerializationUtils.SerializationSettings.GetTypeInfo(typeof(ProcessContentResponse))); + + // Act + await this._client.ProcessContentAsync(request, CancellationToken.None); + + // Assert + Assert.Equal("evaluateInline", this._handler.PreferHeader); + } + [Fact] public async Task ProcessContentAsync_WithRateLimitError_ThrowsPurviewRateLimitExceptionAsync() { @@ -530,6 +548,7 @@ internal sealed class PurviewClientHttpMessageHandlerStub : HttpMessageHandler public HttpMethod? RequestMethod { get; private set; } public string? AuthorizationHeader { get; private set; } public string? IfNoneMatchHeader { get; private set; } + public string? PreferHeader { get; private set; } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { @@ -547,6 +566,11 @@ protected override async Task SendAsync(HttpRequestMessage this.IfNoneMatchHeader = string.Join(", ", ifNoneMatchValues); } + if (request.Headers.TryGetValues("Prefer", out var preferValues)) + { + this.PreferHeader = string.Join(", ", preferValues); + } + // Throw HttpRequestException if configured if (this.ShouldThrowHttpRequestException) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/ScopedContentProcessorTests.cs b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/ScopedContentProcessorTests.cs index 3527cc9884..3cfc81face 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/ScopedContentProcessorTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Purview.UnitTests/ScopedContentProcessorTests.cs @@ -3,12 +3,14 @@ using System; using System.Collections.Generic; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.Agents.AI.Purview.Models.Common; using Microsoft.Agents.AI.Purview.Models.Jobs; using Microsoft.Agents.AI.Purview.Models.Requests; using Microsoft.Agents.AI.Purview.Models.Responses; using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging.Abstractions; using Moq; namespace Microsoft.Agents.AI.Purview.UnitTests; @@ -50,10 +52,6 @@ public async Task ProcessMessagesAsync_WithBlockAccessAction_ReturnsShouldBlockT this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); - this._mockCacheProvider.Setup(x => x.GetAsync( - It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - var psResponse = new ProtectionScopesResponse { Scopes = @@ -70,8 +68,8 @@ public async Task ProcessMessagesAsync_WithBlockAccessAction_ReturnsShouldBlockT ] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); var pcResponse = new ProcessContentResponse @@ -109,10 +107,6 @@ public async Task ProcessMessagesAsync_WithRestrictionActionBlock_ReturnsShouldB this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); - this._mockCacheProvider.Setup(x => x.GetAsync( - It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - var psResponse = new ProtectionScopesResponse { Scopes = @@ -129,8 +123,8 @@ public async Task ProcessMessagesAsync_WithRestrictionActionBlock_ReturnsShouldB ] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); var pcResponse = new ProcessContentResponse @@ -168,10 +162,6 @@ public async Task ProcessMessagesAsync_WithNoBlockingActions_ReturnsShouldBlockF this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); - this._mockCacheProvider.Setup(x => x.GetAsync( - It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - var psResponse = new ProtectionScopesResponse { Scopes = @@ -188,8 +178,8 @@ public async Task ProcessMessagesAsync_WithNoBlockingActions_ReturnsShouldBlockF ] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); var pcResponse = new ProcessContentResponse @@ -213,6 +203,99 @@ public async Task ProcessMessagesAsync_WithNoBlockingActions_ReturnsShouldBlockF Assert.Equal("user-123", result.userId); } + [Fact] + public async Task ProcessMessagesAsync_DeduplicatesCombinedPolicyActionsByActionAndRestrictionAsync() + { + // Arrange + List messages = + [ + new(ChatRole.User, "Test message") + ]; + PurviewSettings settings = CreateValidPurviewSettings(); + TokenInfo tokenInfo = new() { TenantId = "tenant-123", UserId = "user-123", ClientId = "client-123" }; + DlpActionInfo processContentAction = new() { Action = DlpAction.BlockAccess, RestrictionAction = RestrictionAction.Block }; + DlpActionInfo duplicateScopeAction = new() { Action = DlpAction.BlockAccess, RestrictionAction = RestrictionAction.Block }; + DlpActionInfo restrictionOnlyAction = new() { RestrictionAction = RestrictionAction.Block }; + ProcessContentResponse pcResponse = new() + { + PolicyActions = + [ + processContentAction + ] + }; + ProtectionScopesResponse psResponse = new() + { + Scopes = + [ + new() + { + Activities = ProtectionScopeActivities.UploadText, + Locations = + [ + new("microsoft.graph.policyLocationApplication", "app-123") + ], + ExecutionMode = ExecutionMode.EvaluateInline, + PolicyActions = + [ + duplicateScopeAction, + restrictionOnlyAction + ] + } + ] + }; + + this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) + .ReturnsAsync(tokenInfo); + + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync(psResponse); + + this._mockPurviewClient.Setup(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync(pcResponse); + + // Act + await this._processor.ProcessMessagesAsync( + messages, "session-123", Activity.UploadText, settings, "user-123", CancellationToken.None); + + // Assert + Assert.NotNull(pcResponse.PolicyActions); + Assert.Equal(2, pcResponse.PolicyActions.Count); + Assert.Same(processContentAction, pcResponse.PolicyActions[0]); + Assert.Same(restrictionOnlyAction, pcResponse.PolicyActions[1]); + } + + [Fact] + public void CheckApplicableScopes_MatchesAnyLocationInScope() + { + // Arrange + ProcessContentRequest pcRequest = CreateProcessContentRequest(); + ProtectionScopesResponse psResponse = new() + { + Scopes = + [ + new() + { + Activities = ProtectionScopeActivities.UploadText, + Locations = + [ + new("microsoft.graph.policyLocationApplication", "app-123"), + new("microsoft.graph.policyLocationApplication", "different-app") + ], + ExecutionMode = ExecutionMode.EvaluateInline + } + ] + }; + + // Act + (bool shouldProcess, _, ExecutionMode executionMode) = ScopedContentProcessor.CheckApplicableScopes(pcRequest, psResponse); + + // Assert + Assert.True(shouldProcess); + Assert.Equal(ExecutionMode.EvaluateInline, executionMode); + } + [Fact] public async Task ProcessMessagesAsync_UsesCachedProtectionScopes_WhenAvailableAsync() { @@ -279,12 +362,9 @@ public async Task ProcessMessagesAsync_InvalidatesCache_WhenProtectionScopeModif this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); - this._mockCacheProvider.Setup(x => x.GetAsync( - It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - var psResponse = new ProtectionScopesResponse { + ScopeIdentifier = "etag-1", Scopes = [ new() @@ -299,8 +379,8 @@ public async Task ProcessMessagesAsync_InvalidatesCache_WhenProtectionScopeModif ] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); var pcResponse = new ProcessContentResponse @@ -336,10 +416,6 @@ public async Task ProcessMessagesAsync_SendsContentActivities_WhenNoApplicableSc this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); - this._mockCacheProvider.Setup(x => x.GetAsync( - It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - var psResponse = new ProtectionScopesResponse { Scopes = @@ -355,8 +431,8 @@ public async Task ProcessMessagesAsync_SendsContentActivities_WhenNoApplicableSc ] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); // Act @@ -432,13 +508,9 @@ public async Task ProcessMessagesAsync_ExtractsUserIdFromMessageAdditionalProper this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); + var psResponse = new ProtectionScopesResponse { Scopes = [] }; this._mockCacheProvider.Setup(x => x.GetAsync( It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - - var psResponse = new ProtectionScopesResponse { Scopes = [] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); // Act @@ -467,13 +539,9 @@ public async Task ProcessMessagesAsync_ExtractsUserIdFromMessageAuthorName_WhenV this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) .ReturnsAsync(tokenInfo); + var psResponse = new ProtectionScopesResponse { Scopes = [] }; this._mockCacheProvider.Setup(x => x.GetAsync( It.IsAny(), It.IsAny())) - .ReturnsAsync((ProtectionScopesResponse?)null); - - var psResponse = new ProtectionScopesResponse { Scopes = [] }; - this._mockPurviewClient.Setup(x => x.GetProtectionScopesAsync( - It.IsAny(), It.IsAny())) .ReturnsAsync(psResponse); // Act @@ -484,10 +552,260 @@ public async Task ProcessMessagesAsync_ExtractsUserIdFromMessageAuthorName_WhenV Assert.Equal(userId, result.userId); } + [Fact] + public async Task ProcessMessagesAsync_CacheMiss_QueuesScopeRetrievalJobAndCallsProcessContentAsync() + { + // Arrange + var messages = new List + { + new (ChatRole.User, "Test message") + }; + var settings = CreateValidPurviewSettings(); + var tokenInfo = new TokenInfo { TenantId = "tenant-123", UserId = "user-123", ClientId = "client-123" }; + this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) + .ReturnsAsync(tokenInfo); + + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync((ProtectionScopesResponse?)null); + + this._mockPurviewClient.Setup(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync(new ProcessContentResponse()); + + // Act + await this._processor.ProcessMessagesAsync( + messages, "session-123", Activity.UploadText, settings, "user-123", CancellationToken.None); + + // Assert: ProcessContent runs in the foreground; GetProtectionScopes is queued as a background job. + this._mockPurviewClient.Verify(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny()), Times.Once); + this._mockPurviewClient.Verify(x => x.GetProtectionScopesAsync( + It.IsAny(), It.IsAny()), Times.Never); + this._mockChannelHandler.Verify(x => x.QueueJob(It.IsAny()), Times.Once); + } + + [Fact] + public async Task ProcessMessagesAsync_CacheMiss_WithProcessContentBlockAction_ReturnsShouldBlockTrueAsync() + { + // Arrange + var messages = new List + { + new (ChatRole.User, "Test message") + }; + var settings = CreateValidPurviewSettings(); + var tokenInfo = new TokenInfo { TenantId = "tenant-123", UserId = "user-123", ClientId = "client-123" }; + this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) + .ReturnsAsync(tokenInfo); + + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync((ProtectionScopesResponse?)null); + + var pcResponse = new ProcessContentResponse + { + PolicyActions = + [ + new() { Action = DlpAction.BlockAccess } + ] + }; + + this._mockPurviewClient.Setup(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync(pcResponse); + + // Act + var result = await this._processor.ProcessMessagesAsync( + messages, "session-123", Activity.UploadText, settings, "user-123", CancellationToken.None); + + // Assert + Assert.True(result.shouldBlock); + this._mockChannelHandler.Verify(x => x.QueueJob(It.IsAny()), Times.Once); + } + + [Fact] + public async Task ProcessMessagesAsync_CacheMiss_StillCallsProcessContentWhenScopeJobCannotQueueAsync() + { + // Arrange + var messages = new List + { + new (ChatRole.User, "Test message") + }; + var settings = CreateValidPurviewSettings(); + var tokenInfo = new TokenInfo { TenantId = "tenant-123", UserId = "user-123", ClientId = "client-123" }; + this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) + .ReturnsAsync(tokenInfo); + + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync((ProtectionScopesResponse?)null); + + this._mockChannelHandler.Setup(x => x.QueueJob(It.IsAny())) + .Throws(new PurviewJobException("queue unavailable")); + + this._mockPurviewClient.Setup(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync(new ProcessContentResponse()); + + // Act + await this._processor.ProcessMessagesAsync( + messages, "session-123", Activity.UploadText, settings, "user-123", CancellationToken.None); + + // Assert: scope warmup is attempted, and ProcessContent still runs when it can't be queued. + this._mockChannelHandler.Verify(x => x.QueueJob(It.IsAny()), Times.Once); + this._mockPurviewClient.Verify(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ProcessMessagesAsync_WithCachedPaymentRequiredState_ThrowsPaymentRequiredAsync() + { + // Arrange + var messages = new List + { + new (ChatRole.User, "Test message") + }; + var settings = CreateValidPurviewSettings(); + var tokenInfo = new TokenInfo { TenantId = "tenant-123", UserId = "user-123", ClientId = "client-123" }; + this._mockPurviewClient.Setup(x => x.GetUserInfoFromTokenAsync(It.IsAny(), null)) + .ReturnsAsync(tokenInfo); + + this._mockCacheProvider.Setup(x => x.GetAsync( + It.IsAny(), It.IsAny())) + .ReturnsAsync(new PaymentRequiredCacheEntry("Payment required")); + + // Act + Assert + await Assert.ThrowsAsync(() => + this._processor.ProcessMessagesAsync( + messages, "session-123", Activity.UploadText, settings, "user-123", CancellationToken.None)); + + this._mockPurviewClient.Verify(x => x.ProcessContentAsync( + It.IsAny(), It.IsAny()), Times.Never); + this._mockChannelHandler.Verify(x => x.QueueJob(It.IsAny()), Times.Never); + } + + [Fact] + public async Task BackgroundJobRunner_ScopeRetrievalPaymentRequired_CachesForSubsequentCallsAsync() + { + // Arrange + Func, Task>? runner = null; + Mock channelHandler = new(); + Mock purviewClient = new(); + Mock cacheProvider = new(); + PurviewSettings settings = new("TestApp") { MaxConcurrentJobConsumers = 1 }; + ProtectionScopesRequest request = new("user-123", "tenant-123") + { + Activities = ProtectionScopeActivities.UploadText, + Locations = + [ + new("microsoft.graph.policyLocationApplication", "app-123") + ] + }; + ProtectionScopesCacheKey cacheKey = new(request); + Channel channel = Channel.CreateUnbounded(); + + channelHandler.Setup(x => x.AddRunner(It.IsAny, Task>>())) + .Callback, Task>>(callback => runner = callback); + + purviewClient.Setup(x => x.GetProtectionScopesAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new PurviewPaymentRequiredException("Payment required")); + + _ = new BackgroundJobRunner(channelHandler.Object, purviewClient.Object, cacheProvider.Object, NullLogger.Instance, settings); + + // Act + Assert.NotNull(runner); + await channel.Writer.WriteAsync(new ScopeRetrievalJob(request, cacheKey, CreateProcessContentRequest())); + channel.Writer.Complete(); + await runner(channel); + + // Assert + cacheProvider.Verify(x => x.SetAsync( + It.Is(key => key.TenantId == "tenant-123"), + It.Is(entry => entry.Message == "Payment required"), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task BackgroundJobRunner_ScopeRetrievalNoApplicableScopes_QueuesContentActivityJobAsync() + { + // Arrange + Func, Task>? runner = null; + Mock channelHandler = new(); + Mock purviewClient = new(); + Mock cacheProvider = new(); + PurviewSettings settings = new("TestApp") { MaxConcurrentJobConsumers = 1 }; + ProtectionScopesRequest request = CreateProtectionScopesRequest(); + ScopeRetrievalJob job = new(request, new ProtectionScopesCacheKey(request), CreateProcessContentRequest()); + Channel channel = Channel.CreateUnbounded(); + + channelHandler.Setup(x => x.AddRunner(It.IsAny, Task>>())) + .Callback, Task>>(callback => runner = callback); + + purviewClient.Setup(x => x.GetProtectionScopesAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ProtectionScopesResponse { Scopes = [] }); + + _ = new BackgroundJobRunner(channelHandler.Object, purviewClient.Object, cacheProvider.Object, NullLogger.Instance, settings); + + // Act + Assert.NotNull(runner); + await channel.Writer.WriteAsync(job); + channel.Writer.Complete(); + await runner(channel); + + // Assert + channelHandler.Verify(x => x.QueueJob(It.IsAny()), Times.Once); + } + #endregion #region Helper Methods + private static ProtectionScopesRequest CreateProtectionScopesRequest() + { + return new ProtectionScopesRequest("user-123", "tenant-123") + { + Activities = ProtectionScopeActivities.UploadText, + Locations = + [ + new("microsoft.graph.policyLocationApplication", "app-123") + ] + }; + } + + private static ProcessContentRequest CreateProcessContentRequest() + { + PurviewTextContent content = new("Test content"); + ProcessConversationMetadata metadata = new(content, "msg-123", false, "Test message", "test-correlation-id"); + ActivityMetadata activityMetadata = new(Activity.UploadText); + DeviceMetadata deviceMetadata = new() + { + OperatingSystemSpecifications = new() + { + OperatingSystemPlatform = "Windows", + OperatingSystemVersion = "10" + } + }; + IntegratedAppMetadata integratedAppMetadata = new() + { + Name = "TestApp", + Version = "1.0" + }; + PolicyLocation policyLocation = new("microsoft.graph.policyLocationApplication", "app-123"); + ProtectedAppMetadata protectedAppMetadata = new(policyLocation) + { + Name = "TestApp", + Version = "1.0" + }; + ContentToProcess contentToProcess = new( + [metadata], + activityMetadata, + deviceMetadata, + integratedAppMetadata, + protectedAppMetadata); + + return new ProcessContentRequest(contentToProcess, "user-123", "tenant-123"); + } + private static PurviewSettings CreateValidPurviewSettings() { return new PurviewSettings("TestApp") diff --git a/python/packages/purview/README.md b/python/packages/purview/README.md index a802cd9615..0a78e07605 100644 --- a/python/packages/purview/README.md +++ b/python/packages/purview/README.md @@ -320,4 +320,5 @@ except (PurviewAuthenticationError, PurviewRateLimitError, PurviewRequestError, - **Streaming Responses**: Post-response policy evaluation presently applies only to non-streaming chat responses. - **Error Handling**: Use `ignore_exceptions` and `ignore_payment_required` settings for graceful degradation. When enabled, errors are logged but don't fail the request. - **Caching**: Protection scopes responses and 402 errors are cached by default with a 4-hour TTL. Cache is automatically invalidated when protection scope state changes. +- **Cold-cache parallelization**: On a `ProtectionScopes` cache miss, scopes are refreshed in the background while `ProcessContent` runs in the foreground. - **Background Processing**: Content Activities and offline Process Content requests are handled asynchronously using background tasks to avoid blocking the main execution flow. diff --git a/python/packages/purview/agent_framework_purview/_processor.py b/python/packages/purview/agent_framework_purview/_processor.py index 241de80d61..eb949287fd 100644 --- a/python/packages/purview/agent_framework_purview/_processor.py +++ b/python/packages/purview/agent_framework_purview/_processor.py @@ -231,18 +231,19 @@ async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> Proce cached_ps_resp = await self._cache.get(cache_key) if cached_ps_resp is not None and isinstance(cached_ps_resp, ProtectionScopesResponse): - ps_resp = cached_ps_resp - else: - ttl = self._settings.get("cache_ttl_seconds") - ttl_seconds = ttl if ttl is not None else 14400 - try: - ps_resp = await self._client.get_protection_scopes(ps_req) - await self._cache.set(cache_key, ps_resp, ttl_seconds=ttl_seconds) - except PurviewPaymentRequiredError as ex: - # Cache the exception at tenant level so all subsequent requests for this tenant fail fast - await self._cache.set(tenant_payment_cache_key, ex, ttl_seconds=ttl_seconds) - raise + return await self._process_with_cached_scopes(pc_request, cached_ps_resp, cache_key) + task = asyncio.create_task(self._refresh_protection_scopes_background(ps_req, cache_key, pc_request)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + return await self._call_process_content(pc_request, cache_key, dlp_actions=[]) + + async def _process_with_cached_scopes( + self, + pc_request: ProcessContentRequest, + ps_resp: ProtectionScopesResponse, + cache_key: str, + ) -> ProcessContentResponse: if ps_resp.scope_identifier: pc_request.scope_identifier = ps_resp.scope_identifier @@ -259,13 +260,7 @@ async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> Proce task.add_done_callback(self._background_tasks.discard) return ProcessContentResponse(id="204", correlation_id=pc_request.correlation_id) - pc_resp = await self._client.process_content(pc_request) - - if pc_request.scope_identifier and pc_resp.protection_scope_state == ProtectionScopeState.MODIFIED: - await self._cache.remove(cache_key) - - pc_resp.policy_actions = self._combine_policy_actions(pc_resp.policy_actions, dlp_actions) - return pc_resp + return await self._call_process_content(pc_request, cache_key, dlp_actions=dlp_actions) # No applicable scopes - send content activities in background ca_req = ContentActivitiesRequest( @@ -281,12 +276,52 @@ async def _process_with_scopes(self, pc_request: ProcessContentRequest) -> Proce # Respond with HttpStatusCode 204(No Content) return ProcessContentResponse(id="204", correlation_id=pc_request.correlation_id) + async def _call_process_content( + self, + pc_request: ProcessContentRequest, + cache_key: str, + dlp_actions: list[DlpActionInfo], + ) -> ProcessContentResponse: + pc_resp = await self._client.process_content(pc_request) + + if pc_request.scope_identifier and pc_resp.protection_scope_state == ProtectionScopeState.MODIFIED: + await self._cache.remove(cache_key) + + if dlp_actions: + pc_resp.policy_actions = self._combine_policy_actions(pc_resp.policy_actions, dlp_actions) + return pc_resp + + async def _refresh_protection_scopes_background( + self, ps_req: ProtectionScopesRequest, cache_key: str, pc_request: ProcessContentRequest + ) -> None: + """Fetch protection scopes and warm the cache without blocking the foreground call.""" + ttl = self._settings.get("cache_ttl_seconds") + ttl_seconds = ttl if ttl is not None else 14400 + try: + ps_resp = await self._client.get_protection_scopes(ps_req) + await self._cache.set(cache_key, ps_resp, ttl_seconds=ttl_seconds) + should_process, _, _ = self._check_applicable_scopes(pc_request, ps_resp) + if not should_process: + ca_req = ContentActivitiesRequest( + user_id=pc_request.user_id, + tenant_id=pc_request.tenant_id, + content_to_process=pc_request.content_to_process, + correlation_id=pc_request.correlation_id, + ) + await self._send_content_activities_background(ca_req) + except PurviewPaymentRequiredError as ex: + tenant_payment_cache_key = f"purview:payment_required:{ps_req.tenant_id}" + await self._cache.set(tenant_payment_cache_key, ex, ttl_seconds=ttl_seconds) + logger.warning("Background protection scopes refresh failed with payment required: %s", ex) + except Exception as ex: + logger.warning("Background protection scopes refresh failed: %s", ex) + async def _process_content_background(self, pc_request: ProcessContentRequest, cache_key: str) -> None: """Process content in background for offline execution mode.""" try: pc_resp = await self._client.process_content(pc_request) - # If protection scope state is modified, make another PC request and invalidate cache + # If protection scopes changed, invalidate cache and retry once. if pc_request.scope_identifier and pc_resp.protection_scope_state == ProtectionScopeState.MODIFIED: await self._cache.remove(cache_key) await self._client.process_content(pc_request) @@ -306,14 +341,10 @@ async def _send_content_activities_background(self, ca_req: ContentActivitiesReq def _combine_policy_actions( existing: list[DlpActionInfo] | None, new_actions: list[DlpActionInfo] ) -> list[DlpActionInfo]: - by_key: dict[str, DlpActionInfo] = {} - for a in existing or []: - if a.action: - by_key[a.action] = a - for a in new_actions: - if a.action: - by_key[a.action] = a - return list(by_key.values()) + combined: dict[tuple[DlpAction | None, RestrictionAction | None], DlpActionInfo] = {} + for action_info in (existing or []) + new_actions: + combined.setdefault((action_info.action, action_info.restriction_action), action_info) + return list(combined.values()) @staticmethod def _check_applicable_scopes( diff --git a/python/packages/purview/tests/purview/test_processor.py b/python/packages/purview/tests/purview/test_processor.py index 285fb338d8..0cc9d7a8a9 100644 --- a/python/packages/purview/tests/purview/test_processor.py +++ b/python/packages/purview/tests/purview/test_processor.py @@ -2,6 +2,7 @@ """Tests for Purview processor.""" +import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -217,10 +218,38 @@ async def test_combine_policy_actions(self, processor: ScopedContentProcessor) - assert action1 in combined assert action2 in combined + async def test_combine_policy_actions_preserves_restriction_only_actions( + self, processor: ScopedContentProcessor + ) -> None: + """Test _combine_policy_actions keeps actions that only set restrictionAction.""" + existing_action = DlpActionInfo(action=DlpAction.OTHER, restrictionAction=RestrictionAction.OTHER) + restriction_only_action = DlpActionInfo(restriction_action=RestrictionAction.BLOCK) + + combined = processor._combine_policy_actions([existing_action], [restriction_only_action]) + + assert combined == [existing_action, restriction_only_action] + + async def test_combine_policy_actions_deduplicates_by_action_and_restriction( + self, processor: ScopedContentProcessor + ) -> None: + """Test _combine_policy_actions removes exact duplicate actions.""" + block_action = DlpActionInfo(action=DlpAction.BLOCK_ACCESS, restriction_action=RestrictionAction.BLOCK) + duplicate_block_action = DlpActionInfo( + action=DlpAction.BLOCK_ACCESS, restriction_action=RestrictionAction.BLOCK + ) + restriction_only_action = DlpActionInfo(restriction_action=RestrictionAction.BLOCK) + + combined = processor._combine_policy_actions( + [block_action], + [duplicate_block_action, restriction_only_action], + ) + + assert combined == [block_action, restriction_only_action] + async def test_process_with_scopes_calls_client_methods( self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory ) -> None: - """Test _process_with_scopes calls get_protection_scopes when scopes response is empty.""" + """Test _process_with_scopes calls process_content immediately and warms scopes in background on cache miss.""" from agent_framework_purview._models import ( ContentActivitiesResponse, ProtectionScopesResponse, @@ -236,38 +265,91 @@ async def test_process_with_scopes_calls_client_methods( response = await processor._process_with_scopes(request) + # On cache miss, ProcessContent runs in the foreground and the response is returned. + assert response.id == "response-123" + mock_client.process_content.assert_called_once() + + # Protection scopes are refreshed in a background task. + await asyncio.gather(*list(processor._background_tasks)) mock_client.get_protection_scopes.assert_called_once() - # When no scopes apply, process_content is not called (activities are sent in background) - mock_client.process_content.assert_not_called() - # The response should have id=204 (No Content) when no scopes apply - assert response.id == "204" + mock_client.send_content_activities.assert_called_once() - async def test_process_with_scopes_ignores_unexpected_cached_value_type( + async def test_process_with_scopes_preserves_restriction_only_policy_actions( self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory ) -> None: - """Test that a corrupted cache entry does not crash processing.""" + """Test cold-cache ProcessContent actions are not dropped when they only contain restrictionAction.""" + from agent_framework_purview._models import ProtectionScopesResponse + + request = process_content_request_factory() + restriction_only_action = DlpActionInfo(restriction_action=RestrictionAction.BLOCK) + + mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []})) + mock_client.process_content = AsyncMock( + return_value=ProcessContentResponse( + id="response-123", + protection_scope_state="notModified", + policy_actions=[restriction_only_action], + ) + ) + + response = await processor._process_with_scopes(request) + + assert response.policy_actions == [restriction_only_action] + await asyncio.gather(*list(processor._background_tasks)) + + async def test_process_with_cached_scopes_preserves_restriction_only_policy_actions( + self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test cached ProtectionScopes actions are not dropped when they only contain restrictionAction.""" from agent_framework_purview._models import ( ExecutionMode, PolicyLocation, PolicyScope, - ProcessContentResponse, ProtectionScopeActivities, ProtectionScopesResponse, ) request = process_content_request_factory() + restriction_only_action = DlpActionInfo(restriction_action=RestrictionAction.BLOCK) + process_content_action = DlpActionInfo(action=DlpAction.OTHER, restriction_action=RestrictionAction.OTHER) + scope_location = PolicyLocation( + data_type="microsoft.graph.policyLocationApplication", + value="app-id", + ) + scope = PolicyScope( + activities=ProtectionScopeActivities.UPLOAD_TEXT, + locations=[scope_location], + policy_actions=[restriction_only_action], + execution_mode=ExecutionMode.EVALUATE_INLINE, + ) - # Return a valid, inline scope so we stay on the normal (non-background) path. - scope_location = PolicyLocation(**{ - "@odata.type": "microsoft.graph.policyLocationApplication", - "value": "app-id", - }) - scope = PolicyScope(**{ - "activities": ProtectionScopeActivities.UPLOAD_TEXT, - "locations": [scope_location], - "execution_mode": ExecutionMode.EVALUATE_INLINE, - }) - mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": [scope]})) + processor._cache.get = AsyncMock( + side_effect=[ + None, + ProtectionScopesResponse(scope_identifier="scope-123", scopes=[scope]), + ] + ) # type: ignore[method-assign] + mock_client.process_content = AsyncMock( + return_value=ProcessContentResponse( + id="response-123", + protection_scope_state="notModified", + policy_actions=[process_content_action], + ) + ) + + response = await processor._process_with_scopes(request) + + assert response.policy_actions == [process_content_action, restriction_only_action] + + async def test_process_with_scopes_ignores_unexpected_cached_value_type( + self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """Test that a corrupted cache entry does not crash processing.""" + from agent_framework_purview._models import ProtectionScopesResponse + + request = process_content_request_factory() + + mock_client.get_protection_scopes = AsyncMock(return_value=ProtectionScopesResponse(**{"value": []})) mock_client.process_content = AsyncMock( return_value=ProcessContentResponse(**{"id": "ok", "protectionScopeState": "notModified"}) ) @@ -279,8 +361,9 @@ async def test_process_with_scopes_ignores_unexpected_cached_value_type( response = await processor._process_with_scopes(request) assert response.id == "ok" - mock_client.get_protection_scopes.assert_called_once() mock_client.process_content.assert_called_once() + await asyncio.gather(*list(processor._background_tasks)) + mock_client.get_protection_scopes.assert_called_once() async def test_process_with_scopes_uses_tenant_payment_exception_cache( self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory @@ -301,8 +384,6 @@ async def test_process_content_background_retries_on_modified_state( self, processor: ScopedContentProcessor, mock_client: AsyncMock, process_content_request_factory ) -> None: """Test offline background processing invalidates cache and retries when scope state changes.""" - from agent_framework_purview._models import ProcessContentResponse - request = process_content_request_factory() request.scope_identifier = "etag-1" @@ -319,6 +400,36 @@ async def test_process_content_background_retries_on_modified_state( processor._cache.remove.assert_called_once_with("purview:protection_scopes:abc") assert mock_client.process_content.call_count == 2 + async def test_background_scope_refresh_caches_payment_required( + self, mock_client: AsyncMock, process_content_request_factory + ) -> None: + """402 raised during background scope refresh is cached at the tenant level.""" + from agent_framework_purview._cache import InMemoryCacheProvider + from agent_framework_purview._exceptions import PurviewPaymentRequiredError + + settings = PurviewSettings( + app_name="Test App", + tenant_id="12345678-1234-1234-1234-123456789012", + purview_app_location=PurviewAppLocation( + location_type=PurviewLocationType.APPLICATION, location_value="app-id" + ), + ) + + cache = InMemoryCacheProvider() + processor = ScopedContentProcessor(mock_client, settings, cache_provider=cache) + + mock_client.get_protection_scopes = AsyncMock(side_effect=PurviewPaymentRequiredError("nope")) + mock_client.process_content = AsyncMock( + return_value=ProcessContentResponse(**{"id": "pc-1", "protectionScopeState": "notModified"}) + ) + + request = process_content_request_factory() + await processor._process_with_scopes(request) + await asyncio.gather(*list(processor._background_tasks)) + + cached = await cache.get(f"purview:payment_required:{request.tenant_id}") + assert isinstance(cached, PurviewPaymentRequiredError) + async def test_map_messages_with_user_id_in_additional_properties(self, mock_client: AsyncMock) -> None: """Test user_id extraction from message additional_properties.""" settings = PurviewSettings( @@ -387,6 +498,8 @@ async def test_process_content_sends_activities_when_not_applicable( self, mock_client: AsyncMock, process_content_request_factory ) -> None: """Test that response is returned when scopes don't apply (activities sent in background).""" + from agent_framework_purview._models import ProtectionScopesResponse + settings = PurviewSettings( app_name="Test App", tenant_id="12345678-1234-1234-1234-123456789012", @@ -398,10 +511,8 @@ async def test_process_content_sends_activities_when_not_applicable( pc_request = process_content_request_factory() - # Mock get_protection_scopes to return no applicable scopes - mock_ps_response = MagicMock() - mock_ps_response.scopes = [] - mock_client.get_protection_scopes.return_value = mock_ps_response + mock_ps_response = ProtectionScopesResponse(scopes=[]) + processor._cache.get = AsyncMock(side_effect=[None, mock_ps_response]) # type: ignore[method-assign] # Mock send_content_activities to return success (called in background) mock_ca_response = MagicMock() @@ -410,8 +521,10 @@ async def test_process_content_sends_activities_when_not_applicable( response = await processor._process_with_scopes(pc_request) - mock_client.get_protection_scopes.assert_called_once() + mock_client.get_protection_scopes.assert_not_called() mock_client.process_content.assert_not_called() + await asyncio.gather(*list(processor._background_tasks)) + mock_client.send_content_activities.assert_called_once() # Response should have id=204 when no scopes apply assert response.id == "204" @@ -419,6 +532,8 @@ async def test_process_content_handles_activities_error( self, mock_client: AsyncMock, process_content_request_factory ) -> None: """Test that errors in background activities don't affect the response.""" + from agent_framework_purview._models import ProtectionScopesResponse + settings = PurviewSettings( app_name="Test App", tenant_id="12345678-1234-1234-1234-123456789012", @@ -430,10 +545,8 @@ async def test_process_content_handles_activities_error( pc_request = process_content_request_factory() - # Mock get_protection_scopes to return no applicable scopes - mock_ps_response = MagicMock() - mock_ps_response.scopes = [] - mock_client.get_protection_scopes.return_value = mock_ps_response + mock_ps_response = ProtectionScopesResponse(scopes=[]) + processor._cache.get = AsyncMock(side_effect=[None, mock_ps_response]) # type: ignore[method-assign] # Mock send_content_activities to return error (called in background task) mock_ca_response = MagicMock() @@ -445,6 +558,8 @@ async def test_process_content_handles_activities_error( # Since activities are sent in background, errors don't affect the response # Response should have id=204 when no scopes apply assert response.id == "204" + await asyncio.gather(*list(processor._background_tasks)) + mock_client.send_content_activities.assert_called_once() class TestUserIdResolution: @@ -656,10 +771,12 @@ async def test_protection_scopes_cached_on_first_call( mock_client.get_protection_scopes.return_value = ProtectionScopesResponse( scope_identifier="scope-123", scopes=[] ) + mock_client.process_content.return_value = ProcessContentResponse(id="ok", protection_scope_state="notModified") messages = [Message(role="user", contents=["Test"])] await processor.process_messages(messages, Activity.UPLOAD_TEXT, user_id="12345678-1234-1234-1234-123456789012") + await asyncio.gather(*list(processor._background_tasks)) mock_client.get_protection_scopes.assert_called_once() @@ -670,7 +787,7 @@ async def test_protection_scopes_cached_on_first_call( async def test_payment_required_exception_cached_at_tenant_level( self, mock_client: AsyncMock, settings: PurviewSettings ) -> None: - """Test that 402 payment required exceptions are cached at tenant level.""" + """Test that background scope 402 returns once, then throws from the tenant-level cache.""" from agent_framework_purview._cache import InMemoryCacheProvider from agent_framework_purview._exceptions import PurviewPaymentRequiredError @@ -678,13 +795,12 @@ async def test_payment_required_exception_cached_at_tenant_level( processor = ScopedContentProcessor(mock_client, settings, cache_provider=cache_provider) mock_client.get_protection_scopes.side_effect = PurviewPaymentRequiredError("Payment required") + mock_client.process_content.return_value = ProcessContentResponse(id="ok", protection_scope_state="notModified") messages = [Message(role="user", contents=["Test"])] - with pytest.raises(PurviewPaymentRequiredError): - await processor.process_messages( - messages, Activity.UPLOAD_TEXT, user_id="12345678-1234-1234-1234-123456789012" - ) + await processor.process_messages(messages, Activity.UPLOAD_TEXT, user_id="12345678-1234-1234-1234-123456789012") + await asyncio.gather(*list(processor._background_tasks)) mock_client.get_protection_scopes.assert_called_once() diff --git a/python/samples/05-end-to-end/purview_agent/README.md b/python/samples/05-end-to-end/purview_agent/README.md index 1cdb7e3ef4..12293ec306 100644 --- a/python/samples/05-end-to-end/purview_agent/README.md +++ b/python/samples/05-end-to-end/purview_agent/README.md @@ -3,7 +3,7 @@ This getting-started sample shows how to attach Microsoft Purview policy evaluation to an Agent Framework `Agent` using the **middleware** approach. **What this sample demonstrates:** -1. Configure an Azure OpenAI chat client +1. Configure a Foundry chat client 2. Add Purview policy enforcement middleware (`PurviewPolicyMiddleware`) 3. Add Purview policy enforcement at the chat client level (`PurviewChatPolicyMiddleware`) 4. Implement a custom cache provider for advanced caching scenarios @@ -17,8 +17,8 @@ This getting-started sample shows how to attach Microsoft Purview policy evaluat | Variable | Required | Purpose | |----------|----------|---------| -| `AZURE_OPENAI_ENDPOINT` | Yes | Azure OpenAI endpoint (https://.openai.azure.com) | -| `AZURE_OPENAI_MODEL` | Optional | Model deployment name (defaults inside SDK if omitted) | +| `FOUNDRY_PROJECT_ENDPOINT` | Yes | Azure AI Foundry project endpoint, for example `https://.services.ai.azure.com/api/projects/` | +| `FOUNDRY_MODEL` | Optional | Model deployment name (defaults to `gpt-4o-mini`) | | `PURVIEW_CLIENT_APP_ID` | Yes* | Client (application) ID used for Purview authentication | | `PURVIEW_USE_CERT_AUTH` | Optional (`true`/`false`) | Switch between certificate and interactive auth | | `PURVIEW_TENANT_ID` | Yes (when cert auth on) | Tenant ID for certificate authentication | @@ -31,7 +31,8 @@ This getting-started sample shows how to attach Microsoft Purview policy evaluat Opens a browser on first run to sign in. ```powershell -$env:AZURE_OPENAI_ENDPOINT = "https://your-openai-instance.openai.azure.com" +$env:FOUNDRY_PROJECT_ENDPOINT = "https://.services.ai.azure.com/api/projects/" +$env:FOUNDRY_MODEL = "gpt-4o-mini" $env:PURVIEW_CLIENT_APP_ID = "00000000-0000-0000-0000-000000000000" ``` @@ -64,22 +65,27 @@ If interactive auth is used, a browser window will appear the first time. ## 4. How It Works -The sample demonstrates three different scenarios: +The sample demonstrates four integration scenarios. Each scenario runs the same three-message sequence via `run_policy_flow(...)`: + +1. **good (cold cache)** - a benign prompt that exercises the cold-cache parallel ProtectionScopes warmup + foreground ProcessContent path. +2. **expected block** - a sensitive prompt containing the Visa test credit card number `4111 1111 1111 1111`. If the tenant has a DLP policy for `Microsoft 365 Copilot and AI apps` targeting the Credit Card sensitive info type with a Block action, this prompt returns the configured `blocked_prompt_message` (default: `Prompt blocked by policy`). If no DLP policy applies, the prompt is allowed (the LLM may still decline on its own, but that is a model-level response, not a Purview block). +3. **good (warm cache)** - a second benign prompt that exercises the warm-cache path. The custom cache provider scenario prints `Cache HIT` for the same protection-scopes key, confirming the cache and middleware state survive a prior block. ### A. Agent Middleware (`run_with_agent_middleware`) -1. Builds an Azure OpenAI chat client (using the environment endpoint / deployment) +1. Builds a Foundry chat client (using the environment project endpoint / deployment) 2. Chooses credential mode (certificate vs interactive) 3. Creates `PurviewPolicyMiddleware` with `PurviewSettings` 4. Injects middleware into the agent at construction -5. Sends two user messages sequentially -6. Prints results (or policy block messages) +5. Runs the three-message `good -> block -> good` orchestration +6. Prints `ALLOWED` or `BLOCKED` per message, plus the model response 7. Uses default caching automatically ### B. Chat Client Middleware (`run_with_chat_middleware`) 1. Creates a chat client with `PurviewChatPolicyMiddleware` attached directly 2. Policy evaluation happens at the chat client level rather than agent level 3. Demonstrates an alternative integration point for Purview policies -4. Uses default caching automatically +4. Runs the same `good -> block -> good` orchestration +5. Uses default caching automatically ### C. Custom Cache Provider (`run_with_custom_cache_provider`) 1. Implements the `CacheProvider` protocol with a custom class (`SimpleDictCacheProvider`) @@ -88,9 +94,27 @@ The sample demonstrates three different scenarios: - `async def get(self, key: str) -> Any | None` - `async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None` - `async def remove(self, key: str) -> None` +4. Runs the `good -> block -> good` orchestration and prints `Cache MISS`/`Cache HIT` traces alongside policy outcomes, showing the cold-cache warmup populating the cache and warm-cache requests skipping ProtectionScopes. + +### D. Default Cache (`run_with_default_cache`) +1. Same as the agent middleware path but with explicit cache TTL and size limits in `PurviewSettings` +2. Uses the default in-memory `CacheProvider` +3. Runs the `good -> block -> good` orchestration **Policy Behavior:** -Prompt blocks set a system-level message: `Prompt blocked by policy` and terminate the run early. Response blocks rewrite the output to `Response blocked by policy`. +Prompt blocks substitute the configured `blocked_prompt_message` (default `Prompt blocked by policy`) and terminate the agent run early. Response blocks substitute `blocked_response_message`. The LLM is never called for a blocked prompt. + +**Seeing a real `BLOCKED` outcome:** +The middle prompt only returns `BLOCKED` if the tenant actually has a Purview DLP policy that matches the request. Specifically, all of the following must be true: + +1. The Entra app id used by `PURVIEW_CLIENT_APP_ID` (the same id Agent Framework sends as `policyLocationApplication.value`) is registered as an integrated AI app in Purview (Settings -> AI app and agent locations). +2. A DLP policy in the tenant targets the location `Microsoft 365 Copilot and AI apps`, scoped to that app id (or `All apps`). +3. The policy has a rule with the condition `Content contains -> Sensitive info types -> Credit Card Number` and an action of `Restrict access to Microsoft 365 Copilot and AI apps -> Block`. +4. The policy is `On` (not `Test mode without notifications`). +5. The signed-in user is in the policy's user scope. +6. Required Graph delegated permissions are admin-consented: `ProtectionScopes.Compute.All`, `Content.Process.All`, `ContentActivity.Write`. + +If any of those are missing, the credit card prompt is allowed at the Purview layer. The model itself may still decline on its own; that response is a model-level refusal, not a Purview block. The cold/warm cache orchestration is still demonstrated either way - the `Cache MISS -> Cache HIT` trace from the custom cache scenario does not depend on a block firing. --- diff --git a/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py b/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py index 5eb2845886..7305ea12e8 100644 --- a/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py +++ b/python/samples/05-end-to-end/purview_agent/sample_purview_agent.py @@ -11,8 +11,8 @@ Note: Caching is automatic and enabled by default. Environment variables: -- AZURE_OPENAI_ENDPOINT (required) -- AZURE_OPENAI_MODEL (optional, defaults to gpt-4o-mini) +- FOUNDRY_PROJECT_ENDPOINT (required) - Azure AI Foundry project endpoint URL +- FOUNDRY_MODEL (optional, defaults to gpt-4o-mini) - PURVIEW_CLIENT_APP_ID (required) - PURVIEW_USE_CERT_AUTH (optional, set to "true" for certificate auth) - PURVIEW_TENANT_ID (required if certificate auth) @@ -45,6 +45,37 @@ JOKER_NAME = "Joker" JOKER_INSTRUCTIONS = "You are good at telling jokes. Keep responses concise." +# Sequential prompts to demonstrate good -> block -> good orchestration. +# The sensitive prompt contains a Visa test credit card number that matches Purview's +# built-in Credit Card sensitive information type. If the tenant has a DLP policy that +# blocks credit card content for Microsoft 365 Copilot and AI apps, the second message +# will be blocked and the third will verify that subsequent calls still flow normally +# after a block. +GOOD_PROMPT_PRIMARY = "Tell me a joke about a pirate." +SENSITIVE_PROMPT = "My corporate credit card is 4111 1111 1111 1111. Please confirm receipt." +GOOD_PROMPT_FOLLOWUP = "Another light joke please." + + +async def run_policy_flow( + label: str, + agent: Agent, + user_id: str | None, + blocked_text: str, +) -> None: + """Run a good -> block candidate -> good sequence and report each outcome.""" + blocked_marker = blocked_text.lower() + prompts = [ + ("good (cold cache)", GOOD_PROMPT_PRIMARY), + ("expected block", SENSITIVE_PROMPT), + ("good (warm cache)", GOOD_PROMPT_FOLLOWUP), + ] + for tag, text in prompts: + response: AgentResponse = await agent.run( + Message("user", [text], additional_properties={"user_id": user_id}) + ) + outcome = "BLOCKED" if blocked_marker in str(response).lower() else "ALLOWED" + print(f"[{label}] {tag}: {outcome}\n{response}\n") + # Custom Cache Provider Implementation class SimpleDictCacheProvider: @@ -138,21 +169,17 @@ def build_credential() -> Any: async def run_with_agent_middleware() -> None: - endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + endpoint = os.environ.get("FOUNDRY_PROJECT_ENDPOINT") if not endpoint: - print("Skipping run: AZURE_OPENAI_ENDPOINT not set") + print("Skipping run: FOUNDRY_PROJECT_ENDPOINT not set") return - deployment = os.environ.get("AZURE_OPENAI_MODEL", "gpt-4o-mini") + deployment = os.environ.get("FOUNDRY_MODEL", "gpt-4o-mini") user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID") - client = FoundryChatClient(model=deployment, endpoint=endpoint, credential=AzureCliCredential()) + client = FoundryChatClient(model=deployment, project_endpoint=endpoint, credential=AzureCliCredential()) - purview_agent_middleware = PurviewPolicyMiddleware( - build_credential(), - PurviewSettings( - app_name="Agent Framework Sample App", - ), - ) + settings = PurviewSettings(app_name="Agent Framework Sample App") + purview_agent_middleware = PurviewPolicyMiddleware(build_credential(), settings) agent = Agent( client=client, @@ -162,39 +189,26 @@ async def run_with_agent_middleware() -> None: ) print("-- Agent MiddlewareTypes Path --") - first: AgentResponse = await agent.run( - Message("user", ["Tell me a joke about a pirate."], additional_properties={"user_id": user_id}) - ) - print("First response (agent middleware):\n", first) - - second: AgentResponse = await agent.run( - Message( - role="user", contents=["That was funny. Tell me another one."], additional_properties={"user_id": user_id} - ) - ) - print("Second response (agent middleware):\n", second) + blocked_text = settings.get("blocked_prompt_message") or "Prompt blocked by policy" + await run_policy_flow("agent middleware", agent, user_id, blocked_text) async def run_with_chat_middleware() -> None: - endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + endpoint = os.environ.get("FOUNDRY_PROJECT_ENDPOINT") if not endpoint: - print("Skipping chat middleware run: AZURE_OPENAI_ENDPOINT not set") + print("Skipping chat middleware run: FOUNDRY_PROJECT_ENDPOINT not set") return - deployment = os.environ.get("AZURE_OPENAI_MODEL", default="gpt-4o-mini") + deployment = os.environ.get("FOUNDRY_MODEL", default="gpt-4o-mini") user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID") + settings = PurviewSettings(app_name="Agent Framework Sample App (Chat)") client = FoundryChatClient( model=deployment, - endpoint=endpoint, + project_endpoint=endpoint, credential=AzureCliCredential(), middleware=[ - PurviewChatPolicyMiddleware( - build_credential(), - PurviewSettings( - app_name="Agent Framework Sample App (Chat)", - ), - ) + PurviewChatPolicyMiddleware(build_credential(), settings) ], ) @@ -205,43 +219,27 @@ async def run_with_chat_middleware() -> None: ) print("-- Chat MiddlewareTypes Path --") - first: AgentResponse = await agent.run( - Message( - role="user", - contents=["Give me a short clean joke."], - additional_properties={"user_id": user_id}, - ) - ) - print("First response (chat middleware):\n", first) - - second: AgentResponse = await agent.run( - Message( - role="user", - contents=["One more please."], - additional_properties={"user_id": user_id}, - ) - ) - print("Second response (chat middleware):\n", second) + blocked_text = settings.get("blocked_prompt_message") or "Prompt blocked by policy" + await run_policy_flow("chat middleware", agent, user_id, blocked_text) async def run_with_custom_cache_provider() -> None: """Demonstrate implementing and using a custom cache provider.""" - endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + endpoint = os.environ.get("FOUNDRY_PROJECT_ENDPOINT") if not endpoint: - print("Skipping custom cache provider run: AZURE_OPENAI_ENDPOINT not set") + print("Skipping custom cache provider run: FOUNDRY_PROJECT_ENDPOINT not set") return - deployment = os.environ.get("AZURE_OPENAI_MODEL", "gpt-4o-mini") + deployment = os.environ.get("FOUNDRY_MODEL", "gpt-4o-mini") user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID") - client = FoundryChatClient(model=deployment, endpoint=endpoint, credential=AzureCliCredential()) + client = FoundryChatClient(model=deployment, project_endpoint=endpoint, credential=AzureCliCredential()) custom_cache = SimpleDictCacheProvider() + settings = PurviewSettings(app_name="Agent Framework Sample App (Custom Provider)") purview_agent_middleware = PurviewPolicyMiddleware( build_credential(), - PurviewSettings( - app_name="Agent Framework Sample App (Custom Provider)", - ), + settings, cache_provider=custom_cache, ) @@ -254,38 +252,28 @@ async def run_with_custom_cache_provider() -> None: print("-- Custom Cache Provider Path --") print("Using SimpleDictCacheProvider") + blocked_text = settings.get("blocked_prompt_message") or "Prompt blocked by policy" + await run_policy_flow("custom cache", agent, user_id, blocked_text) - first: AgentResponse = await agent.run( - Message( - role="user", contents=["Tell me a joke about a programmer."], additional_properties={"user_id": user_id} - ) - ) - print("First response (custom provider):\n", first) - - second: AgentResponse = await agent.run( - Message("user", ["That's hilarious! One more?"], additional_properties={"user_id": user_id}) - ) - print("Second response (custom provider):\n", second) +async def run_with_default_cache() -> None: """Demonstrate using the default built-in cache.""" - endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT") + endpoint = os.environ.get("FOUNDRY_PROJECT_ENDPOINT") if not endpoint: - print("Skipping default cache run: AZURE_OPENAI_ENDPOINT not set") + print("Skipping default cache run: FOUNDRY_PROJECT_ENDPOINT not set") return - deployment = os.environ.get("AZURE_OPENAI_MODEL", "gpt-4o-mini") + deployment = os.environ.get("FOUNDRY_MODEL", "gpt-4o-mini") user_id = os.environ.get("PURVIEW_DEFAULT_USER_ID") - client = FoundryChatClient(model=deployment, endpoint=endpoint, credential=AzureCliCredential()) + client = FoundryChatClient(model=deployment, project_endpoint=endpoint, credential=AzureCliCredential()) # No cache_provider specified - uses default InMemoryCacheProvider - purview_agent_middleware = PurviewPolicyMiddleware( - build_credential(), - PurviewSettings( - app_name="Agent Framework Sample App (Default Cache)", - cache_ttl_seconds=3600, - max_cache_size_bytes=100 * 1024 * 1024, # 100MB - ), + settings = PurviewSettings( + app_name="Agent Framework Sample App (Default Cache)", + cache_ttl_seconds=3600, + max_cache_size_bytes=100 * 1024 * 1024, # 100MB ) + purview_agent_middleware = PurviewPolicyMiddleware(build_credential(), settings) agent = Agent( client=client, @@ -296,16 +284,8 @@ async def run_with_custom_cache_provider() -> None: print("-- Default Cache Path --") print("Using default InMemoryCacheProvider with settings-based configuration") - - first: AgentResponse = await agent.run( - Message("user", ["Tell me a joke about AI."], additional_properties={"user_id": user_id}) - ) - print("First response (default cache):\n", first) - - second: AgentResponse = await agent.run( - Message("user", ["Nice! Another AI joke please."], additional_properties={"user_id": user_id}) - ) - print("Second response (default cache):\n", second) + blocked_text = settings.get("blocked_prompt_message") or "Prompt blocked by policy" + await run_policy_flow("default cache", agent, user_id, blocked_text) async def main() -> None: @@ -326,6 +306,11 @@ async def main() -> None: except Exception as ex: # pragma: no cover - demo resilience print(f"Custom cache provider path failed: {ex}") + try: + await run_with_default_cache() + except Exception as ex: # pragma: no cover - demo resilience + print(f"Default cache path failed: {ex}") + if __name__ == "__main__": asyncio.run(main())