Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from langgraph.constants import END, START
from langgraph.graph.state import StateGraph
from langgraph.prebuilt.tool_node import ToolCallWithContext, ToolNode
from langgraph.runtime import Runtime # noqa: TC002
from langgraph.runtime import Runtime
from langgraph.types import Command, Send
from langgraph.typing import ContextT # noqa: TC002
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict

from langchain.agents.middleware.types import (
Expand Down Expand Up @@ -541,7 +541,7 @@ async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
return result


def create_agent( # noqa: PLR0915
def create_agent(
model: str | BaseChatModel,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
*,
Expand Down Expand Up @@ -786,9 +786,9 @@ def check_weather(location: str) -> str:
default_tools = list(built_in_tools)

# validate middleware
assert len({m.name for m in middleware}) == len(middleware), ( # noqa: S101
"Please remove duplicate middleware instances."
)
if len({m.name for m in middleware}) != len(middleware):
msg = "Please remove duplicate middleware instances."
raise AssertionError(msg)
middleware_w_before_agent = [
m
for m in middleware
Expand Down Expand Up @@ -881,12 +881,12 @@ def _handle_model_output(
)
try:
structured_response = provider_strategy_binding.parse(output)
except Exception as exc: # noqa: BLE001
except Exception as exc:
schema_name = getattr(
effective_response_format.schema_spec.schema, "__name__", "response_format"
)
validation_error = StructuredOutputValidationError(schema_name, exc, output)
raise validation_error
raise validation_error from exc
else:
return {"messages": [output], "structured_response": structured_response}
return {"messages": [output]}
Expand Down Expand Up @@ -947,13 +947,13 @@ def _handle_model_output(
],
"structured_response": structured_response,
}
except Exception as exc: # noqa: BLE001
except Exception as exc:
exception = StructuredOutputValidationError(tool_call["name"], exc, output)
should_retry, error_message = _handle_structured_output_error(
exception, effective_response_format
)
if not should_retry:
raise exception
raise exception from exc

return {
"messages": [
Expand Down
30 changes: 22 additions & 8 deletions libs/langchain_v1/langchain/agents/middleware/_redaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,22 @@ def detect_url(content: str) -> list[PIIMatch]:
}
"""Registry of built-in detectors keyed by type name."""

_CARD_NUMBER_MIN_DIGITS = 13
_CARD_NUMBER_MAX_DIGITS = 19


def _passes_luhn(card_number: str) -> bool:
"""Validate credit card number using the Luhn checksum."""
digits = [int(d) for d in card_number if d.isdigit()]
if not 13 <= len(digits) <= 19:
if not _CARD_NUMBER_MIN_DIGITS <= len(digits) <= _CARD_NUMBER_MAX_DIGITS:
return False

checksum = 0
for index, digit in enumerate(reversed(digits)):
value = digit
if index % 2 == 1:
value *= 2
if value > 9:
if value > 9: # noqa: PLR2004
value -= 9
checksum += value
return checksum % 10 == 0
Expand All @@ -205,18 +208,22 @@ def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
return result


_UNMASKED_CHAR_NUMBER = 4
_IPV4_PARTS_NUMBER = 4


def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
result = content
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
value = match["value"]
pii_type = match["type"]
if pii_type == "email":
parts = value.split("@")
if len(parts) == 2:
if len(parts) == 2: # noqa: PLR2004
domain_parts = parts[1].split(".")
masked = (
f"{parts[0]}@****.{domain_parts[-1]}"
if len(domain_parts) >= 2
if len(domain_parts) > 1
else f"{parts[0]}@****"
)
else:
Expand All @@ -225,12 +232,15 @@ def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
digits_only = "".join(c for c in value if c.isdigit())
separator = "-" if "-" in value else " " if " " in value else ""
if separator:
masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}"
masked = (
f"****{separator}****{separator}****{separator}"
f"{digits_only[-_UNMASKED_CHAR_NUMBER:]}"
)
else:
masked = f"************{digits_only[-4:]}"
masked = f"************{digits_only[-_UNMASKED_CHAR_NUMBER:]}"
elif pii_type == "ip":
octets = value.split(".")
masked = f"*.*.*.{octets[-1]}" if len(octets) == 4 else "****"
masked = f"*.*.*.{octets[-1]}" if len(octets) == _IPV4_PARTS_NUMBER else "****"
elif pii_type == "mac_address":
separator = ":" if ":" in value else "-"
masked = (
Expand All @@ -239,7 +249,11 @@ def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
elif pii_type == "url":
masked = "[MASKED_URL]"
else:
masked = f"****{value[-4:]}" if len(value) > 4 else "****"
masked = (
f"****{value[-_UNMASKED_CHAR_NUMBER:]}"
if len(value) > _UNMASKED_CHAR_NUMBER
else "****"
)
result = result[: match["start"]] + masked + result[match["end"] :]
return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def wrap_model_call(

def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)

else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
Expand Down Expand Up @@ -258,6 +259,7 @@ async def awrap_model_call(

def count_tokens(messages: Sequence[BaseMessage]) -> int:
return count_tokens_approximately(messages)

else:
system_msg = (
[SystemMessage(content=request.system_prompt)] if request.system_prompt else []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from langchain_core.messages import AIMessage
from langgraph.channels.untracked_value import UntrackedValue
from typing_extensions import NotRequired
from typing_extensions import NotRequired, override

from langchain.agents.middleware.types import (
AgentMiddleware,
Expand Down Expand Up @@ -157,7 +157,8 @@ def __init__(
self.exit_behavior = exit_behavior

@hook_config(can_jump_to=["end"])
def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
@override
def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
"""Check model call limits before making a model call.

Args:
Expand Down Expand Up @@ -222,7 +223,8 @@ async def abefore_model(
"""
return self.before_model(state, runtime)

def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
@override
def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:
"""Increment model call counts after a model call.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def wrap_model_call(
last_exception: Exception
try:
return handler(request)
except Exception as e: # noqa: BLE001
except Exception as e:
last_exception = e

# Try fallback models
for fallback_model in self.models:
try:
return handler(request.override(model=fallback_model))
except Exception as e: # noqa: BLE001
except Exception as e:
last_exception = e
continue

Expand All @@ -121,14 +121,14 @@ async def awrap_model_call(
last_exception: Exception
try:
return await handler(request)
except Exception as e: # noqa: BLE001
except Exception as e:
last_exception = e

# Try fallback models
for fallback_model in self.models:
try:
return await handler(request.override(model=fallback_model))
except Exception as e: # noqa: BLE001
except Exception as e:
last_exception = e
continue

Expand Down
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/model_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def wrap_model_call(
for attempt in range(self.max_retries + 1):
try:
return handler(request)
except Exception as exc: # noqa: BLE001
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed

# Check if we should retry this exception
Expand Down Expand Up @@ -270,7 +270,7 @@ async def awrap_model_call(
for attempt in range(self.max_retries + 1):
try:
return await handler(request)
except Exception as exc: # noqa: BLE001
except Exception as exc:
attempts_made = attempt + 1 # attempt is 0-indexed

# Check if we should retry this exception
Expand Down
9 changes: 7 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/pii.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Any, Literal

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from typing_extensions import override

from langchain.agents.middleware._redaction import (
PIIDetectionError,
Expand Down Expand Up @@ -92,6 +93,8 @@ class PIIMiddleware(AgentMiddleware):

def __init__(
self,
# From a typing point of view, the literals are covered by 'str'.
# Nonetheless, we escape PYI051 to keep hints and autocompletion for the caller.
pii_type: Literal["email", "credit_card", "ip", "mac_address", "url"] | str, # noqa: PYI051
*,
strategy: Literal["block", "redact", "mask", "hash"] = "redact",
Expand Down Expand Up @@ -158,10 +161,11 @@ def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
return sanitized, matches

@hook_config(can_jump_to=["end"])
@override
def before_model(
self,
state: AgentState,
runtime: Runtime, # noqa: ARG002
runtime: Runtime,
) -> dict[str, Any] | None:
"""Check user messages and tool results for PII before model invocation.

Expand Down Expand Up @@ -273,10 +277,11 @@ async def abefore_model(
"""
return self.before_model(state, runtime)

@override
def after_model(
self,
state: AgentState,
runtime: Runtime, # noqa: ARG002
runtime: Runtime,
) -> dict[str, Any] | None:
"""Check AI messages for PII after model invocation.

Expand Down
14 changes: 8 additions & 6 deletions libs/langchain_v1/langchain/agents/middleware/shell_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from langgraph.channels.untracked_value import UntrackedValue
from pydantic import BaseModel, model_validator
from pydantic.json_schema import SkipJsonSchema
from typing_extensions import NotRequired
from typing_extensions import NotRequired, override

from langchain.agents.middleware._execution import (
SHELL_TEMP_PREFIX,
Expand Down Expand Up @@ -78,10 +78,10 @@ class _SessionResources:
session: ShellSession
tempdir: tempfile.TemporaryDirectory[str] | None
policy: BaseExecutionPolicy
_finalizer: weakref.finalize = field(init=False, repr=False)
finalizer: weakref.finalize = field(init=False, repr=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is accessed in ShellToolMiddleware.after_agent():


_SessionResources is already private so it doesn't change the exposure a lot.


def __post_init__(self) -> None:
self._finalizer = weakref.finalize(
self.finalizer = weakref.finalize(
self,
_cleanup_resources,
self.session,
Expand Down Expand Up @@ -489,7 +489,8 @@ def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
normalized[key] = str(value)
return normalized

def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
@override
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Start the shell session and run startup commands."""
resources = self._get_or_create_resources(state)
return {"shell_session_resources": resources}
Expand All @@ -498,7 +499,8 @@ async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[s
"""Async start the shell session and run startup commands."""
return self.before_agent(state, runtime)

def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa: ARG002
@override
def after_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Run shutdown commands and release resources when an agent completes."""
resources = state.get("shell_session_resources")
if not isinstance(resources, _SessionResources):
Expand All @@ -507,7 +509,7 @@ def after_agent(self, state: ShellToolState, runtime: Runtime) -> None: # noqa:
try:
self._run_shutdown_commands(resources.session)
finally:
resources._finalizer()
resources.finalizer()

async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async run shutdown commands and release resources when an agent completes."""
Expand Down
13 changes: 8 additions & 5 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
REMOVE_ALL_MESSAGES,
)
from langgraph.runtime import Runtime
from typing_extensions import override

from langchain.agents.middleware.types import AgentMiddleware, AgentState
from langchain.chat_models import BaseChatModel, init_chat_model
Expand Down Expand Up @@ -165,7 +166,8 @@ def __init__(
)
raise ValueError(msg)

def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
@override
def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)
Expand All @@ -192,7 +194,8 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
]
}

async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
@override
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""Process messages before model invocation, potentially triggering summarization."""
messages = state["messages"]
self._ensure_message_ids(messages)
Expand Down Expand Up @@ -438,7 +441,7 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
return response.text.strip()
except Exception as e: # noqa: BLE001
except Exception as e:
return f"Error generating summary: {e!s}"

async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
Expand All @@ -455,7 +458,7 @@ async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str
self.summary_prompt.format(messages=trimmed_messages)
)
return response.text.strip()
except Exception as e: # noqa: BLE001
except Exception as e:
return f"Error generating summary: {e!s}"

def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
Expand All @@ -472,5 +475,5 @@ def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMess
allow_partial=True,
include_system=True,
)
except Exception: # noqa: BLE001
except Exception:
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
Loading
Loading