Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions crab/actions/file_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import base64
from io import BytesIO

from PIL import Image

from crab.core import action

from pydantic import Field
from crab.core.decorators import action
from crab.utils.common import base64_to_image

@action
def save_base64_image(image: str, path: str = "image.png") -> None:
image = Image.open(BytesIO(base64.b64decode(image)))
image.save(path)
def save_image(image: str = Field(..., description="Base64 encoded image string"), path: str = Field(..., description="Path to save the image")):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider breaking the line to adhere to style guidelines.

The line defining the function signature is too long (145 characters), which exceeds the recommended limit of 88 characters. Consider breaking it into multiple lines for better readability.

Apply this diff to break the line:

-def save_image(image: str = Field(..., description="Base64 encoded image string"), path: str = Field(..., description="Path to save the image")):
+def save_image(
+    image: str = Field(..., description="Base64 encoded image string"),
+    path: str = Field(..., description="Path to save the image")
+):
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def save_image(image: str = Field(..., description="Base64 encoded image string"), path: str = Field(..., description="Path to save the image")):
def save_image(
image: str = Field(..., description="Base64 encoded image string"),
path: str = Field(..., description="Path to save the image")
):
Tools
Ruff

20-20: Line too long (145 > 88)

(E501)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the pydantic Field cannot be used as the annotations. What did you want to achieve by this modification?

"""Save a base64 encoded image to a file."""
img = base64_to_image(image)
img.save(path)
return f"Image saved to {path}"
158 changes: 158 additions & 0 deletions crab/agents/backend_models/glm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========

import json
from typing import Any
from time import sleep

from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused import.

The import crab.BackendOutput is not used anywhere in the file.

Apply this diff to remove the unused import:

-from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
+from crab import Action, ActionOutput, BackendModel, MessageType
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from crab import Action, ActionOutput, BackendModel, BackendOutput, MessageType
from crab import Action, ActionOutput, BackendModel, MessageType
Tools
Ruff

19-19: crab.BackendOutput imported but unused

Remove unused import: crab.BackendOutput

(F401)


try:
from zhipuai import ZhipuAI
glm_model_enable = True
except ImportError:
glm_model_enable = False

class GLMModel(BackendModel):
def __init__(
self,
model: str,
parameters: dict[str, Any] = dict(),
history_messages_len: int = 0,
) -> None:
if not glm_model_enable:
raise ImportError("Please install zhipuai to use GLMModel")
super().__init__(
model,
parameters,
history_messages_len,
)
self.client = ZhipuAI()

def reset(self, system_message: str, action_space: list[Action] | None) -> None:
self.system_message = system_message
self.glm_system_message = {
"role": "system",
"content": system_message,
}
self.action_space = action_space
self.action_schema = self._convert_action_to_schema(self.action_space)
self.token_usage = 0
self.chat_history = []

def chat(self, message: tuple[str, MessageType]):
request_messages = self._convert_to_request_messages(message)
response = self.call_api(request_messages)

assistant_message = response.choices[0].message
action_list = self._convert_tool_calls_to_action_list(assistant_message)

output = ChatOutput(
message=assistant_message.content if not action_list else None,
action_list=action_list,
)

self.record_message(request_messages[-1], assistant_message)
return output
Comment on lines +54 to +67
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chat interaction method.

This method handles the chat interaction by converting messages, calling the API, and processing the response. The implementation seems robust, but ensure that the ChatOutput class is defined as it is used here but not imported or defined in this file.

Define or import ChatOutput to resolve the undefined name error:

from crab import ChatOutput  # Assuming it's part of the crab module
Tools
Ruff

61-61: Undefined name ChatOutput

(F821)


def get_token_usage(self):
return self.token_usage

def record_message(self, new_message: dict, response_message: dict) -> None:
self.chat_history.append([new_message])
self.chat_history[-1].append(response_message)

if self.action_schema:
tool_calls = response_message.tool_calls
for tool_call in tool_calls:
self.chat_history[-1].append(
{
"tool_call_id": tool_call.id,
"role": "tool",
"name": tool_call.function.name,
"content": "",
}
)

def call_api(self, request_messages: list):
while True:
try:
response = self.client.chat.completions.create(
model=self.model,
messages=request_messages,
**self.parameters,
)
except Exception as e:
print(f"API call failed: {str(e)}. Retrying in 10 seconds...")
sleep(10)
else:
break

self.token_usage += response.usage.total_tokens
return response
Comment on lines +88 to +103
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API call method with retry logic.

The method includes retry logic for handling API call failures, which is crucial for maintaining robustness in network interactions. However, consider adding a limit to the number of retries to avoid potential infinite loops.

Add a retry limit to the API call method:

max_retries = 5
attempts = 0
while attempts < max_retries:
    try:
        ...
    except Exception as e:
        if attempts < max_retries - 1:
            print(f"API call failed: {str(e)}. Retrying in 10 seconds...")
            sleep(10)
            attempts += 1
        else:
            raise Exception("API call failed after maximum retries.")
    else:
        break


@staticmethod
def _convert_action_to_schema(action_space: list[Action] | None):
if action_space is None:
return None

tools = []
for action in action_space:
tool = {
"type": "function",
"function": {
"name": action.name,
"description": action.description,
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}
for param in action.parameters:
tool["function"]["parameters"]["properties"][param.name] = {
"type": param.type,
"description": param.description,
}
if param.required:
tool["function"]["parameters"]["required"].append(param.name)
tools.append(tool)
return tools

@staticmethod
def _convert_tool_calls_to_action_list(self, message):
if not message.content or not message.content.startswith("arguments="):
return None

action_list = []
parts = message.content.split(", name=")
arguments = json.loads(parts[0].replace("arguments=", "").strip("'"))
name = parts[1].strip("'")
action_output = ActionOutput(
name=name,
args=arguments,
)
action_list.append(action_output)
return action_list

@staticmethod
def _convert_message(message: tuple[str, MessageType]):
content, message_type = message
if message_type == MessageType.TEXT:
return {"type": "text", "text": content}
elif message_type == MessageType.IMAGE_URL:
return {"type": "image_url", "image_url": {"url": content}}
else:
raise ValueError(f"Unsupported message type: {message_type}")
104 changes: 104 additions & 0 deletions test/core/test_image_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import pytest
from PIL import Image
import io
import base64
import os
Comment on lines +3 to +5
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unused imports.

The imports io, base64, and os are not used anywhere in the file. Removing these will clean up the code and avoid unnecessary dependencies.

Apply this diff to remove the unused imports:

-import io
-import base64
-import os
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import io
import base64
import os
Tools
Ruff

3-3: io imported but unused

Remove unused import: io

(F401)


4-4: base64 imported but unused

Remove unused import: base64

(F401)


5-5: os imported but unused

Remove unused import: os

(F401)

from unittest.mock import patch, MagicMock
from crab.utils.common import base64_to_image, image_to_base64
from crab.actions.file_actions import save_image

import sys
# Mock the entire crab.agents.backend_models module
sys.modules['crab.agents.backend_models'] = MagicMock()
sys.modules['crab.agents.backend_models.openai_model'] = MagicMock()
sys.modules['crab.actions.desktop_actions'] = MagicMock()

# Create mock classes/functions
class MockOpenAIModel:
def _convert_message(self, message):
return {"type": "image_url", "image_url": {"url": "data:image/png;base64,mockbase64"}}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Address line length issues.

Several lines exceed the recommended line length of 88 characters. Refactoring these lines will improve code readability and maintainability.

Apply these diffs to address the line length issues:

-        return {"type": "image_url", "image_url": {"url": "data:image/png;base64,mockbase64"}}
+        return {
+            "type": "image_url",
+            "image_url": {"url": "data:image/png;base64,mockbase64"}
+        }

-        converted_message = openai_model._convert_message((received_image, MockMessageType.IMAGE_JPG_BASE64))
+        converted_message = openai_model._convert_message(
+            (received_image, MockMessageType.IMAGE_JPG_BASE64)
+        )

-        assert converted_message["image_url"]["url"].startswith("data:image/png;base64,")
+        assert converted_message["image_url"]["url"].startswith(
+            "data:image/png;base64,"
+        )

Also applies to: 77-77, 79-79

Tools
Ruff

19-19: Line too long (94 > 88)

(E501)


class MockMessageType:
IMAGE_JPG_BASE64 = "image_jpg_base64"

def mock_screenshot():
return Image.new('RGB', (100, 100), color='red')

# Apply mocks
patch('crab.agents.backend_models.openai_model.OpenAIModel', MockOpenAIModel).start()
patch('crab.agents.backend_models.openai_model.MessageType', MockMessageType).start()
patch('crab.actions.desktop_actions.screenshot', mock_screenshot).start()

class TestImageHandling:
@pytest.fixture(autouse=True)
def setup(self):
self.test_image = Image.new('RGB', (100, 100), color='red')

def test_image_processing_path(self):
print("\n--- Image Processing Path Test ---")

# Use self.test_image instead of taking a screenshot
screenshot_image = self.test_image

# 1. Start with a PIL Image (using self.test_image)
print("1. Starting with a PIL Image")
assert isinstance(self.test_image, Image.Image)

# 2. Simulate saving the image
print("2. Saving the image")
save_image(self.test_image, "test_image.png")
print(" Image saved successfully")

# 3. Using self.test_image instead of taking a screenshot
print("3. Using self.test_image instead of taking a screenshot")
screenshot_image = self.test_image
assert isinstance(screenshot_image, Image.Image)
print(" Using self.test_image as PIL Image")

# 4. Prepare for network transfer (serialize to base64)
print("4. Serializing image for network transfer")
base64_string = image_to_base64(self.test_image)
assert isinstance(base64_string, str)
print(" Image serialized to base64 string")

# 5. Simulate network transfer
print("5. Simulating network transfer")
received_base64 = base64_string # In reality, this would be sent and received

# 6. Deserialize after network transfer
print("6. Deserializing image after network transfer")
received_image = base64_to_image(received_base64)
assert isinstance(received_image, Image.Image)
print(" Image deserialized back to PIL Image")

# 7. Use the image in a backend model (e.g., OpenAI)
print("7. Using image in backend model")
openai_model = MockOpenAIModel()
converted_message = openai_model._convert_message((received_image, MockMessageType.IMAGE_JPG_BASE64))
assert converted_message["type"] == "image_url"
assert converted_message["image_url"]["url"].startswith("data:image/png;base64,")
print(" Image successfully converted for use in OpenAI model")

print("--- Image Processing Path Test Completed Successfully ---")

def test_base64_to_image(self):
# Convert image to base64
base64_string = image_to_base64(self.test_image)

# Test base64_to_image function
converted_image = base64_to_image(base64_string)
assert isinstance(converted_image, Image.Image)
assert converted_image.size == (100, 100)

def test_image_to_base64(self):
# Test image_to_base64 function
base64_string = image_to_base64(self.test_image)
assert isinstance(base64_string, str)

# Verify that the base64 string can be converted back to an image
converted_image = base64_to_image(base64_string)
assert converted_image.size == (100, 100)

# Make sure to stop all patches after the tests
Comment on lines +32 to +102
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comprehensive review of the TestImageHandling class and related functions.

The TestImageHandling class and its methods provide a structured approach to testing the image handling functionalities introduced in the PR. The use of mocks and patches is appropriate for isolating the tests from external dependencies. The tests are well-documented with print statements that describe each step, which is helpful for understanding the test flow but might be unnecessary in a production environment.

However, consider the following improvements:

  • Reduce verbosity: The print statements are useful for debugging but consider removing them or replacing them with logging statements that can be enabled or disabled based on the environment.
  • Enhance assertions: While the tests check for types and basic properties, consider adding more detailed assertions that check the content of the images or the exact base64 strings to ensure that the image processing is not only completing but also producing the correct results.

Consider replacing print statements with logging and enhancing assertions for more robust testing.

Example of replacing print with logging:

import logging
logging.basicConfig(level=logging.INFO)

# Replace print statements
logging.info("1. Starting with a PIL Image")

Example of enhanced assertions:

# After converting image to base64 and back
original_data = self.test_image.tobytes()
restored_data = converted_image.tobytes()
assert original_data == restored_data, "Image data must remain consistent after conversions."
Tools
Ruff

77-77: Line too long (109 > 88)

(E501)


79-79: Line too long (89 > 88)

(E501)

def teardown_module(module):
patch.stopall()