cognee/infrastructure/llm/LLMGateway.py (0 additions & 13 deletions)
@@ -34,19 +34,6 @@ def acreate_structured_output(
             text_input=text_input, system_prompt=system_prompt, response_model=response_model
         )

-    @staticmethod
-    def create_structured_output(
-        text_input: str, system_prompt: str, response_model: Type[BaseModel]
-    ) -> BaseModel:
-        from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
-            get_llm_client,
-        )
-
-        llm_client = get_llm_client()
-        return llm_client.create_structured_output(
-            text_input=text_input, system_prompt=system_prompt, response_model=response_model
-        )
-
     @staticmethod
     def create_transcript(input) -> Coroutine:
         from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.get_llm_client import (
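With the synchronous `create_structured_output` wrapper removed, callers are left with the async entry point on the gateway. A minimal sketch of a migrated call site, assuming the module keeps the `LLMGateway` class name implied by its path (the `Summary` model and input strings are illustrative, not from this diff):

```python
import asyncio
from pydantic import BaseModel

from cognee.infrastructure.llm.LLMGateway import LLMGateway


class Summary(BaseModel):
    # Illustrative response model; any Pydantic model can be used here.
    text: str


async def main() -> None:
    # acreate_structured_output is now the only structured-output path on the gateway.
    summary = await LLMGateway.acreate_structured_output(
        text_input="Cognee builds knowledge graphs from raw text.",
        system_prompt="Summarize the input in one sentence.",
        response_model=Summary,
    )
    print(summary.text)


asyncio.run(main())
```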
@@ -3,7 +3,9 @@
 from pydantic import BaseModel
 import litellm
 import instructor
+import anthropic
 from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe
 from tenacity import (
     retry,
     stop_after_delay,
@@ -12,37 +14,40 @@
     before_sleep_log,
 )

-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
 )
 from cognee.infrastructure.llm.config import get_llm_config

 logger = get_logger()
+observe = get_observe()


-class AnthropicAdapter(LLMInterface):
+class AnthropicAdapter(GenericAPIAdapter):
     """
     Adapter for interfacing with the Anthropic API, enabling structured output generation
     and prompt display.
     """

-    name = "Anthropic"
-    model: str
     default_instructor_mode = "anthropic_tools"

-    def __init__(self, max_completion_tokens: int, model: str = None, instructor_mode: str = None):
-        import anthropic
-
+    def __init__(
+        self, api_key: str, model: str, max_completion_tokens: int, instructor_mode: str = None
+    ):
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Anthropic",
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

         self.aclient = instructor.patch(
             create=anthropic.AsyncAnthropic(api_key=get_llm_config().llm_api_key).messages.create,
             mode=instructor.Mode(self.instructor_mode),
         )

-        self.model = model
-        self.max_completion_tokens = max_completion_tokens
-
+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(2, 128),
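After this change, the adapter takes its credentials explicitly and defers shared state to `GenericAPIAdapter`, while the unchanged `instructor.patch` call still reads the key from `get_llm_config()`. A rough sketch of the new call shape (the adapter's module path is not shown in this diff, and the key/model values are placeholders):

```python
# Hypothetical construction; AnthropicAdapter is assumed imported from its adapter module.
adapter = AnthropicAdapter(
    api_key="sk-ant-...",              # placeholder; now a required constructor argument
    model="claude-sonnet-placeholder",  # placeholder model name
    max_completion_tokens=4096,
)
# instructor_mode falls back to the class default "anthropic_tools" when omitted.
assert adapter.name == "Anthropic"  # now set via super().__init__, not a class attribute
```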
@@ -1,4 +1,4 @@
-"""Adapter for Generic API LLM provider API"""
+"""Adapter for Gemini API LLM provider"""

 import litellm
 import instructor
@@ -8,12 +8,7 @@
 from litellm.exceptions import ContentPolicyViolationError
 from instructor.core import InstructorRetryException

-from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
-from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
-    LLMInterface,
-)
 import logging
-from cognee.shared.logging_utils import get_logger
 from tenacity import (
     retry,
     stop_after_delay,
@@ -22,55 +17,65 @@
     before_sleep_log,
 )

+from cognee.infrastructure.llm.exceptions import ContentPolicyFilterError
+from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.generic_llm_api.adapter import (
+    GenericAPIAdapter,
+)
+from cognee.shared.logging_utils import get_logger
+from cognee.modules.observability.get_observe import get_observe

 logger = get_logger()
+observe = get_observe()


-class GeminiAdapter(LLMInterface):
+class GeminiAdapter(GenericAPIAdapter):
     """
     Adapter for Gemini API LLM provider.

     This class initializes the API adapter with necessary credentials and configurations for
     interacting with the gemini LLM models. It provides methods for creating structured outputs
-    based on user input and system prompts.
+    based on user input and system prompts, as well as multimodal processing capabilities.

     Public methods:
-    - acreate_structured_output(text_input: str, system_prompt: str, response_model:
-    Type[BaseModel]) -> BaseModel
+    - acreate_structured_output(text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel
+    - create_transcript(input) -> BaseModel: Transcribe audio files to text
+    - transcribe_image(input) -> BaseModel: Inherited from GenericAPIAdapter
     """

-    name: str
-    model: str
-    api_key: str
     default_instructor_mode = "json_mode"

     def __init__(
         self,
-        endpoint,
         api_key: str,
         model: str,
-        api_version: str,
         max_completion_tokens: int,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
         fallback_endpoint: str = None,
     ):
-        self.model = model
-        self.api_key = api_key
-        self.endpoint = endpoint
-        self.api_version = api_version
-        self.max_completion_tokens = max_completion_tokens
-
-        self.fallback_model = fallback_model
-        self.fallback_api_key = fallback_api_key
-        self.fallback_endpoint = fallback_endpoint
+        super().__init__(
+            api_key=api_key,
+            model=model,
+            max_completion_tokens=max_completion_tokens,
+            name="Gemini",
+            endpoint=endpoint,
+            api_version=api_version,
+            transcription_model=transcription_model,
+            fallback_model=fallback_model,
+            fallback_api_key=fallback_api_key,
+            fallback_endpoint=fallback_endpoint,
+        )
         self.instructor_mode = instructor_mode if instructor_mode else self.default_instructor_mode

         self.aclient = instructor.from_litellm(
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )

+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(2, 128),
@@ -118,7 +123,7 @@ async def acreate_structured_output(
                 },
             ],
             api_key=self.api_key,
-            max_retries=5,
+            max_retries=self.MAX_RETRIES,
             api_base=self.endpoint,
             api_version=self.api_version,
             response_model=response_model,
@@ -152,7 +157,7 @@ async def acreate_structured_output(
                     "content": system_prompt,
                 },
             ],
-            max_retries=5,
+            max_retries=self.MAX_RETRIES,
             api_key=self.fallback_api_key,
             api_base=self.fallback_endpoint,
             response_model=response_model,
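Since `GeminiAdapter` now routes all state through `GenericAPIAdapter.__init__`, the endpoint, API version, transcription model, and fallback settings become optional keyword arguments, and the inherited `MAX_RETRIES` replaces the hard-coded `5`. A rough sketch of constructing the adapter with a fallback (all values are placeholders, not from this diff):

```python
# Hypothetical construction; GeminiAdapter is assumed imported from its adapter module.
adapter = GeminiAdapter(
    api_key="gemini-key-placeholder",
    model="gemini/gemini-2.0-flash",        # placeholder litellm model string
    max_completion_tokens=8192,
    # endpoint / api_version may stay None for the default Gemini API.
    fallback_model="openai/gpt-4o-mini",    # consulted by acreate_structured_output's fallback path
    fallback_api_key="openai-key-placeholder",
)
```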
@@ -1,8 +1,10 @@
 """Adapter for Generic API LLM provider API"""

+import base64
+import mimetypes
 import litellm
 import instructor
-from typing import Type
+from typing import Type, Optional
 from pydantic import BaseModel
 from openai import ContentFilterFinishReasonError
 from litellm.exceptions import ContentPolicyViolationError
@@ -12,6 +14,8 @@
 from cognee.infrastructure.llm.structured_output_framework.litellm_instructor.llm.llm_interface import (
     LLMInterface,
 )
+from cognee.infrastructure.files.utils.open_data_file import open_data_file
+from cognee.modules.observability.get_observe import get_observe
 import logging
 from cognee.shared.logging_utils import get_logger
 from tenacity import (
@@ -23,6 +27,7 @@
 )

 logger = get_logger()
+observe = get_observe()


 class GenericAPIAdapter(LLMInterface):
@@ -38,18 +43,19 @@ class GenericAPIAdapter(LLMInterface):
         Type[BaseModel]) -> BaseModel
     """

-    name: str
-    model: str
-    api_key: str
+    MAX_RETRIES = 5
     default_instructor_mode = "json_mode"

     def __init__(
         self,
-        endpoint,
         api_key: str,
         model: str,
-        name: str,
         max_completion_tokens: int,
+        name: str,
+        endpoint: str = None,
+        api_version: str = None,
+        transcription_model: str = None,
+        image_transcribe_model: str = None,
         instructor_mode: str = None,
         fallback_model: str = None,
         fallback_api_key: str = None,
@@ -58,9 +64,11 @@ def __init__(
         self.name = name
         self.model = model
         self.api_key = api_key
+        self.api_version = api_version
         self.endpoint = endpoint
         self.max_completion_tokens = max_completion_tokens
-
+        self.transcription_model = transcription_model or model
+        self.image_transcribe_model = image_transcribe_model or model
         self.fallback_model = fallback_model
         self.fallback_api_key = fallback_api_key
         self.fallback_endpoint = fallback_endpoint
@@ -71,6 +79,7 @@ def __init__(
             litellm.acompletion, mode=instructor.Mode(self.instructor_mode)
         )

+    @observe(as_type="generation")
     @retry(
         stop=stop_after_delay(128),
         wait=wait_exponential_jitter(2, 128),
@@ -170,3 +179,112 @@ async def acreate_structured_output(
             raise ContentPolicyFilterError(
                 f"The provided input contains content that is not aligned with our content policy: {text_input}"
             ) from error
+
+    @observe(as_type="transcription")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def create_transcript(self, input) -> Optional[BaseModel]:
+        """
+        Generate an audio transcript from a user query.
+
+        This method creates a transcript from the specified audio file, raising a
+        FileNotFoundError if the file does not exist. The audio file is processed and the
+        transcription is retrieved from the API.
+
+        Parameters:
+        -----------
+
+            - input: The path to the audio file that needs to be transcribed.
+
+        Returns:
+        --------
+
+            The generated transcription of the audio file.
+        """
+        async with open_data_file(input, mode="rb") as audio_file:
+            encoded_string = base64.b64encode(audio_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("audio/"):
+            raise ValueError(
+                f"Could not determine MIME type for audio file: {input}. Is the extension correct?"
+            )
+        return litellm.completion(
+            model=self.transcription_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "file",
+                            "file": {"file_data": f"data:{mime_type};base64,{encoded_string}"},
+                        },
+                        {"type": "text", "text": "Transcribe the following audio precisely."},
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_version=self.api_version,
+            max_completion_tokens=self.max_completion_tokens,
+            api_base=self.endpoint,
+            max_retries=self.MAX_RETRIES,
+        )
+
+    @observe(as_type="transcribe_image")
+    @retry(
+        stop=stop_after_delay(128),
+        wait=wait_exponential_jitter(2, 128),
+        retry=retry_if_not_exception_type(litellm.exceptions.NotFoundError),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+        reraise=True,
+    )
+    async def transcribe_image(self, input) -> Optional[BaseModel]:
+        """
+        Generate a transcription of an image from a user query.
+
+        This method encodes the image and sends a request to the API to obtain a
+        description of the contents of the image.
+
+        Parameters:
+        -----------
+
+            - input: The path to the image file that needs to be transcribed.
+
+        Returns:
+        --------
+
+            - BaseModel: A structured output generated by the model, returned as an instance of
+              BaseModel.
+        """
+        async with open_data_file(input, mode="rb") as image_file:
+            encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
+        mime_type, _ = mimetypes.guess_type(input)
+        if not mime_type or not mime_type.startswith("image/"):
+            raise ValueError(
+                f"Could not determine MIME type for image file: {input}. Is the extension correct?"
+            )
+        return litellm.completion(
+            model=self.image_transcribe_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "What's in this image?",
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:{mime_type};base64,{encoded_image}",
+                            },
+                        },
+                    ],
+                }
+            ],
+            api_key=self.api_key,
+            api_base=self.endpoint,
+            api_version=self.api_version,
+            max_completion_tokens=300,
+            max_retries=self.MAX_RETRIES,
+        )
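Both new methods read the file through `open_data_file`, base64-encode it, and guess the MIME type from the filename before handing litellm a data URL. A sketch of how a caller might exercise them, assuming a configured adapter, files with recognizable extensions (paths are placeholders), and litellm's standard completion response shape:

```python
import asyncio


async def demo(adapter: "GenericAPIAdapter") -> None:
    # The MIME type is guessed from the extension, so it must resolve to audio/* here...
    transcript = await adapter.create_transcript("meeting.mp3")
    print(transcript.choices[0].message.content)

    # ...and to image/* here; otherwise both methods raise ValueError.
    description = await adapter.transcribe_image("diagram.png")
    print(description.choices[0].message.content)


# asyncio.run(demo(some_adapter))  # some_adapter: a constructed GenericAPIAdapter
```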