diff --git a/FirebaseAI/Sources/Chat.swift b/FirebaseAI/Sources/Chat.swift index 42da2ef4a6d..80e908a8f57 100644 --- a/FirebaseAI/Sources/Chat.swift +++ b/FirebaseAI/Sources/Chat.swift @@ -147,31 +147,48 @@ public final class Chat: Sendable { } private func aggregatedChunks(_ chunks: [ModelContent]) -> ModelContent { - var parts: [any Part] = [] + var parts: [InternalPart] = [] var combinedText = "" - for aggregate in chunks { - // Loop through all the parts, aggregating the text and adding the images. - for part in aggregate.parts { - switch part { - case let textPart as TextPart: - combinedText += textPart.text - - default: - // Don't combine it, just add to the content. If there's any text pending, add that as - // a part. + var combinedThoughts = "" + + func flush() { + if !combinedThoughts.isEmpty { + parts.append(InternalPart(.text(combinedThoughts), isThought: true, thoughtSignature: nil)) + combinedThoughts = "" + } + if !combinedText.isEmpty { + parts.append(InternalPart(.text(combinedText), isThought: nil, thoughtSignature: nil)) + combinedText = "" + } + } + + // Loop through all the parts, aggregating the text. + for part in chunks.flatMap({ $0.internalParts }) { + // Only text parts may be combined. + if case let .text(text) = part.data, part.thoughtSignature == nil { + // Thought summaries must not be combined with regular text. + if part.isThought ?? false { + // If we were combining regular text, flush it before handling "thoughts". if !combinedText.isEmpty { - parts.append(TextPart(combinedText)) - combinedText = "" + flush() } - - parts.append(part) + combinedThoughts += text + } else { + // If we were combining "thoughts", flush it before handling regular text. + if !combinedThoughts.isEmpty { + flush() + } + combinedText += text } + } else { + // This is a non-combinable part (not text), flush any pending text. 
+ flush() + parts.append(part) } } - if !combinedText.isEmpty { - parts.append(TextPart(combinedText)) - } + // Flush any remaining text. + flush() return ModelContent(role: "model", parts: parts) } diff --git a/FirebaseAI/Sources/GenerateContentResponse.swift b/FirebaseAI/Sources/GenerateContentResponse.swift index 1cc9874e795..e1e7a79a686 100644 --- a/FirebaseAI/Sources/GenerateContentResponse.swift +++ b/FirebaseAI/Sources/GenerateContentResponse.swift @@ -58,29 +58,11 @@ public struct GenerateContentResponse: Sendable { /// The response's content as text, if it exists. public var text: String? { - guard let candidate = candidates.first else { - AILog.error( - code: .generateContentResponseNoCandidates, - "Could not get text from a response that had no candidates." - ) - return nil - } - let textValues: [String] = candidate.content.parts.compactMap { part in - switch part { - case let textPart as TextPart: - return textPart.text - default: - return nil - } - } - guard textValues.count > 0 else { - AILog.error( - code: .generateContentResponseNoText, - "Could not get a text part from the first candidate." - ) - return nil - } - return textValues.joined(separator: " ") + return text(isThought: false) + } + + public var thoughtSummary: String? { + return text(isThought: true) } /// Returns function calls found in any `Part`s of the first candidate of the response, if any. @@ -89,12 +71,10 @@ public struct GenerateContentResponse: Sendable { return [] } return candidate.content.parts.compactMap { part in - switch part { - case let functionCallPart as FunctionCallPart: - return functionCallPart - default: + guard let functionCallPart = part as? FunctionCallPart, !part.isThought else { return nil } + return functionCallPart } } @@ -107,7 +87,12 @@ public struct GenerateContentResponse: Sendable { """) return [] } - return candidate.content.parts.compactMap { $0 as? 
InlineDataPart } + return candidate.content.parts.compactMap { part in + guard let inlineDataPart = part as? InlineDataPart, !part.isThought else { + return nil + } + return inlineDataPart + } } /// Initializer for SwiftUI previews or tests. @@ -117,6 +102,30 @@ public struct GenerateContentResponse: Sendable { self.promptFeedback = promptFeedback self.usageMetadata = usageMetadata } + + func text(isThought: Bool) -> String? { + guard let candidate = candidates.first else { + AILog.error( + code: .generateContentResponseNoCandidates, + "Could not get text from a response that had no candidates." + ) + return nil + } + let textValues: [String] = candidate.content.parts.compactMap { part in + guard let textPart = part as? TextPart, part.isThought == isThought else { + return nil + } + return textPart.text + } + guard textValues.count > 0 else { + AILog.error( + code: .generateContentResponseNoText, + "Could not get a text part from the first candidate." + ) + return nil + } + return textValues.joined(separator: " ") + } } /// A struct representing a possible reply to a content generation prompt. Each content generation diff --git a/FirebaseAI/Sources/ModelContent.swift b/FirebaseAI/Sources/ModelContent.swift index 7d82bd76445..558c30b2789 100644 --- a/FirebaseAI/Sources/ModelContent.swift +++ b/FirebaseAI/Sources/ModelContent.swift @@ -31,19 +31,34 @@ extension [ModelContent] { } } -/// A type describing data in media formats interpretable by an AI model. Each generative AI -/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value -/// may comprise multiple heterogeneous ``Part``s. 
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public struct ModelContent: Equatable, Sendable { - enum InternalPart: Equatable, Sendable { +struct InternalPart: Equatable, Sendable { + enum OneOfData: Equatable, Sendable { case text(String) - case inlineData(mimetype: String, Data) - case fileData(mimetype: String, uri: String) + case inlineData(InlineData) + case fileData(FileData) case functionCall(FunctionCall) case functionResponse(FunctionResponse) } + let data: OneOfData + + let isThought: Bool? + + let thoughtSignature: String? + + init(_ data: OneOfData, isThought: Bool?, thoughtSignature: String?) { + self.data = data + self.isThought = isThought + self.thoughtSignature = thoughtSignature + } +} + +/// A type describing data in media formats interpretable by an AI model. Each generative AI +/// request or response contains an `Array` of ``ModelContent``s, and each ``ModelContent`` value +/// may comprise multiple heterogeneous ``Part``s. +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +public struct ModelContent: Equatable, Sendable { /// The role of the entity creating the ``ModelContent``. For user-generated client requests, /// for example, the role is `user`. public let role: String? 
@@ -52,17 +67,29 @@ public struct ModelContent: Equatable, Sendable { public var parts: [any Part] { var convertedParts = [any Part]() for part in internalParts { - switch part { + switch part.data { case let .text(text): - convertedParts.append(TextPart(text)) - case let .inlineData(mimetype, data): - convertedParts.append(InlineDataPart(data: data, mimeType: mimetype)) - case let .fileData(mimetype, uri): - convertedParts.append(FileDataPart(uri: uri, mimeType: mimetype)) + convertedParts.append( + TextPart(text, isThought: part.isThought, thoughtSignature: part.thoughtSignature) + ) + case let .inlineData(inlineData): + convertedParts.append(InlineDataPart( + inlineData, isThought: part.isThought, thoughtSignature: part.thoughtSignature + )) + case let .fileData(fileData): + convertedParts.append(FileDataPart( + fileData, + isThought: part.isThought, + thoughtSignature: part.thoughtSignature + )) case let .functionCall(functionCall): - convertedParts.append(FunctionCallPart(functionCall)) + convertedParts.append(FunctionCallPart( + functionCall, isThought: part.isThought, thoughtSignature: part.thoughtSignature + )) case let .functionResponse(functionResponse): - convertedParts.append(FunctionResponsePart(functionResponse)) + convertedParts.append(FunctionResponsePart( + functionResponse, isThought: part.isThought, thoughtSignature: part.thoughtSignature + )) } } return convertedParts @@ -78,17 +105,35 @@ public struct ModelContent: Equatable, Sendable { for part in parts { switch part { case let textPart as TextPart: - convertedParts.append(.text(textPart.text)) + convertedParts.append(InternalPart( + .text(textPart.text), + isThought: textPart._isThought, + thoughtSignature: textPart.thoughtSignature + )) case let inlineDataPart as InlineDataPart: - let inlineData = inlineDataPart.inlineData - convertedParts.append(.inlineData(mimetype: inlineData.mimeType, inlineData.data)) + convertedParts.append(InternalPart( + .inlineData(inlineDataPart.inlineData), + 
isThought: inlineDataPart._isThought, + thoughtSignature: inlineDataPart.thoughtSignature + )) case let fileDataPart as FileDataPart: - let fileData = fileDataPart.fileData - convertedParts.append(.fileData(mimetype: fileData.mimeType, uri: fileData.fileURI)) + convertedParts.append(InternalPart( + .fileData(fileDataPart.fileData), + isThought: fileDataPart._isThought, + thoughtSignature: fileDataPart.thoughtSignature + )) case let functionCallPart as FunctionCallPart: - convertedParts.append(.functionCall(functionCallPart.functionCall)) + convertedParts.append(InternalPart( + .functionCall(functionCallPart.functionCall), + isThought: functionCallPart._isThought, + thoughtSignature: functionCallPart.thoughtSignature + )) case let functionResponsePart as FunctionResponsePart: - convertedParts.append(.functionResponse(functionResponsePart.functionResponse)) + convertedParts.append(InternalPart( + .functionResponse(functionResponsePart.functionResponse), + isThought: functionResponsePart._isThought, + thoughtSignature: functionResponsePart.thoughtSignature + )) default: fatalError() } @@ -102,6 +147,11 @@ public struct ModelContent: Equatable, Sendable { let content = parts.flatMap { $0.partsValue } self.init(role: role, parts: content) } + + init(role: String?, parts: [InternalPart]) { + self.role = role + internalParts = parts + } } // MARK: Codable Conformances @@ -121,7 +171,29 @@ extension ModelContent: Codable { } @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -extension ModelContent.InternalPart: Codable { +extension InternalPart: Codable { + enum CodingKeys: String, CodingKey { + case isThought = "thought" + case thoughtSignature + } + + public func encode(to encoder: Encoder) throws { + try data.encode(to: encoder) + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encodeIfPresent(isThought, forKey: .isThought) + try container.encodeIfPresent(thoughtSignature, forKey: .thoughtSignature) + } + + public 
init(from decoder: Decoder) throws { + data = try OneOfData(from: decoder) + let container = try decoder.container(keyedBy: CodingKeys.self) + isThought = try container.decodeIfPresent(Bool.self, forKey: .isThought) + thoughtSignature = try container.decodeIfPresent(String.self, forKey: .thoughtSignature) + } +} + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +extension InternalPart.OneOfData: Codable { enum CodingKeys: String, CodingKey { case text case inlineData @@ -135,10 +207,10 @@ extension ModelContent.InternalPart: Codable { switch self { case let .text(text): try container.encode(text, forKey: .text) - case let .inlineData(mimetype, bytes): - try container.encode(InlineData(data: bytes, mimeType: mimetype), forKey: .inlineData) - case let .fileData(mimetype: mimetype, url): - try container.encode(FileData(fileURI: url, mimeType: mimetype), forKey: .fileData) + case let .inlineData(inlineData): + try container.encode(inlineData, forKey: .inlineData) + case let .fileData(fileData): + try container.encode(fileData, forKey: .fileData) case let .functionCall(functionCall): try container.encode(functionCall, forKey: .functionCall) case let .functionResponse(functionResponse): @@ -151,11 +223,9 @@ extension ModelContent.InternalPart: Codable { if values.contains(.text) { self = try .text(values.decode(String.self, forKey: .text)) } else if values.contains(.inlineData) { - let inlineData = try values.decode(InlineData.self, forKey: .inlineData) - self = .inlineData(mimetype: inlineData.mimeType, inlineData.data) + self = try .inlineData(values.decode(InlineData.self, forKey: .inlineData)) } else if values.contains(.fileData) { - let fileData = try values.decode(FileData.self, forKey: .fileData) - self = .fileData(mimetype: fileData.mimeType, uri: fileData.fileURI) + self = try .fileData(values.decode(FileData.self, forKey: .fileData)) } else if values.contains(.functionCall) { self = try 
.functionCall(values.decode(FunctionCall.self, forKey: .functionCall)) } else if values.contains(.functionResponse) { diff --git a/FirebaseAI/Sources/Types/Internal/InternalPart.swift b/FirebaseAI/Sources/Types/Internal/InternalPart.swift index d543fb80f38..bb62dd4c0b5 100644 --- a/FirebaseAI/Sources/Types/Internal/InternalPart.swift +++ b/FirebaseAI/Sources/Types/Internal/InternalPart.swift @@ -67,6 +67,9 @@ struct FunctionResponse: Codable, Equatable, Sendable { struct ErrorPart: Part, Error { let error: Error + let isThought = false + let thoughtSignature: String? = nil + init(_ error: Error) { self.error = error } diff --git a/FirebaseAI/Sources/Types/Public/Part.swift b/FirebaseAI/Sources/Types/Public/Part.swift index 4890b725f4d..fb743d1025d 100644 --- a/FirebaseAI/Sources/Types/Public/Part.swift +++ b/FirebaseAI/Sources/Types/Public/Part.swift @@ -18,7 +18,14 @@ import Foundation /// /// Within a single value of ``Part``, different data types may not mix. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) -public protocol Part: PartsRepresentable, Codable, Sendable, Equatable {} +public protocol Part: PartsRepresentable, Codable, Sendable, Equatable { + /// Indicates whether this `Part` is a summary of the model's internal thinking process. + /// + /// When `includeThoughts` is set to `true` in ``ThinkingConfig``, the model may return one or + /// more "thought" parts that provide insight into how it reasoned through the prompt to arrive + /// at the final answer. These parts will have `isThought` set to `true`. + var isThought: Bool { get } +} /// A text part containing a string value. @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) @@ -26,8 +33,20 @@ public struct TextPart: Part { /// Text value. public let text: String + public var isThought: Bool { _isThought ?? false } + + let thoughtSignature: String? + + let _isThought: Bool? 
+ public init(_ text: String) { + self.init(text, isThought: nil, thoughtSignature: nil) + } + + init(_ text: String, isThought: Bool?, thoughtSignature: String?) { self.text = text + _isThought = isThought + self.thoughtSignature = thoughtSignature } } @@ -45,6 +64,7 @@ public struct TextPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct InlineDataPart: Part { let inlineData: InlineData + let _isThought: Bool? /// The data provided in the inline data part. public var data: Data { inlineData.data } @@ -52,6 +72,10 @@ public struct InlineDataPart: Part { /// The IANA standard MIME type of the data. public var mimeType: String { inlineData.mimeType } + public var isThought: Bool { _isThought ?? false } + + let thoughtSignature: String? + /// Creates an inline data part from data and a MIME type. /// /// > Important: Supported input types depend on the model on the model being used; see [input @@ -67,11 +91,13 @@ public struct InlineDataPart: Part { /// requirements](https://firebase.google.com/docs/vertex-ai/input-file-requirements) for /// supported values. public init(data: Data, mimeType: String) { - self.init(InlineData(data: data, mimeType: mimeType)) + self.init(InlineData(data: data, mimeType: mimeType), isThought: nil, thoughtSignature: nil) } - init(_ inlineData: InlineData) { + init(_ inlineData: InlineData, isThought: Bool?, thoughtSignature: String?) { self.inlineData = inlineData + _isThought = isThought + self.thoughtSignature = thoughtSignature } } @@ -79,9 +105,12 @@ public struct InlineDataPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FileDataPart: Part { let fileData: FileData + let _isThought: Bool? + let thoughtSignature: String? public var uri: String { fileData.fileURI } public var mimeType: String { fileData.mimeType } + public var isThought: Bool { _isThought ?? false } /// Constructs a new file data part. 
/// @@ -93,11 +122,13 @@ public struct FileDataPart: Part { /// requirements](https://firebase.google.com/docs/vertex-ai/input-file-requirements) for /// supported values. public init(uri: String, mimeType: String) { - self.init(FileData(fileURI: uri, mimeType: mimeType)) + self.init(FileData(fileURI: uri, mimeType: mimeType), isThought: nil, thoughtSignature: nil) } - init(_ fileData: FileData) { + init(_ fileData: FileData, isThought: Bool?, thoughtSignature: String?) { self.fileData = fileData + _isThought = isThought + self.thoughtSignature = thoughtSignature } } @@ -105,6 +136,8 @@ public struct FileDataPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FunctionCallPart: Part { let functionCall: FunctionCall + let _isThought: Bool? + let thoughtSignature: String? /// The name of the function to call. public var name: String { functionCall.name } @@ -112,6 +145,8 @@ public struct FunctionCallPart: Part { /// The function parameters and values. public var args: JSONObject { functionCall.args } + public var isThought: Bool { _isThought ?? false } + /// Constructs a new function call part. /// /// > Note: A `FunctionCallPart` is typically received from the model, rather than created @@ -121,11 +156,13 @@ public struct FunctionCallPart: Part { /// - name: The name of the function to call. /// - args: The function parameters and values. public init(name: String, args: JSONObject) { - self.init(FunctionCall(name: name, args: args)) + self.init(FunctionCall(name: name, args: args), isThought: nil, thoughtSignature: nil) } - init(_ functionCall: FunctionCall) { + init(_ functionCall: FunctionCall, isThought: Bool?, thoughtSignature: String?) 
{ self.functionCall = functionCall + _isThought = isThought + self.thoughtSignature = thoughtSignature } } @@ -137,6 +174,8 @@ public struct FunctionCallPart: Part { @available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) public struct FunctionResponsePart: Part { let functionResponse: FunctionResponse + let _isThought: Bool? + let thoughtSignature: String? /// The name of the function that was called. public var name: String { functionResponse.name } @@ -144,16 +183,22 @@ public struct FunctionResponsePart: Part { /// The function's response or return value. public var response: JSONObject { functionResponse.response } + public var isThought: Bool { _isThought ?? false } + /// Constructs a new `FunctionResponse`. /// /// - Parameters: /// - name: The name of the function that was called. /// - response: The function's response. public init(name: String, response: JSONObject) { - self.init(FunctionResponse(name: name, response: response)) + self.init( + FunctionResponse(name: name, response: response), isThought: nil, thoughtSignature: nil + ) } - init(_ functionResponse: FunctionResponse) { + init(_ functionResponse: FunctionResponse, isThought: Bool?, thoughtSignature: String?) { self.functionResponse = functionResponse + _isThought = isThought + self.thoughtSignature = thoughtSignature } } diff --git a/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift b/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift index c0e8f31465b..a339f8fa1d1 100644 --- a/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift +++ b/FirebaseAI/Sources/Types/Public/ThinkingConfig.swift @@ -37,12 +37,24 @@ public struct ThinkingConfig: Sendable { /// feature or if the specified budget is not within the model's supported range. let thinkingBudget: Int? + /// Whether summaries of the model's "thoughts" are included in responses. 
+ /// + /// When `includeThoughts` is set to `true`, the model will return a summary of its internal + /// thinking process alongside the final answer. This can provide valuable insight into how the + /// model arrived at its conclusion, which is particularly useful for complex or creative tasks. + /// + /// If you don't specify a value for `includeThoughts` (`nil`), the model will use its default + /// behavior (which is typically to not include thought summaries). + let includeThoughts: Bool? + /// Initializes a new `ThinkingConfig`. /// /// - Parameters: /// - thinkingBudget: The maximum number of tokens to be used for the model's thinking process. - public init(thinkingBudget: Int? = nil) { + /// - includeThoughts: If true, summaries of the model's "thoughts" are included in responses. + public init(thinkingBudget: Int? = nil, includeThoughts: Bool? = nil) { self.thinkingBudget = thinkingBudget + self.includeThoughts = includeThoughts } } diff --git a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift index d83c300623d..c8e452c7dd9 100644 --- a/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift +++ b/FirebaseAI/Tests/TestApp/Tests/Integration/GenerateContentIntegrationTests.swift @@ -134,47 +134,83 @@ struct GenerateContentIntegrationTests { #expect(candidatesTokensDetails.tokenCount == usageMetadata.candidatesTokenCount) } - @Test(arguments: [ - (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2_5_Flash, 0), - (InstanceConfig.vertexAI_v1beta, ModelNames.gemini2_5_Flash, 24576), - (InstanceConfig.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, 128), - (InstanceConfig.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, 32768), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Flash, 0), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Flash, 24576), - (InstanceConfig.googleAI_v1beta, 
ModelNames.gemini2_5_Pro, 128), - (InstanceConfig.googleAI_v1beta, ModelNames.gemini2_5_Pro, 32768), - (InstanceConfig.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, 0), - (InstanceConfig.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, 24576), - ]) + @Test( + arguments: [ + (.vertexAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 0)), + (.vertexAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 24576)), + (.vertexAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 128)), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 32768)), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig( + thinkingBudget: 32768, includeThoughts: true + )), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 0)), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 24576)), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 128)), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: 32768)), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig( + thinkingBudget: 32768, includeThoughts: true + )), + (.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: 0)), + ( + .googleAI_v1beta_freeTier, + ModelNames.gemini2_5_Flash, + ThinkingConfig(thinkingBudget: 24576) + ), + (.googleAI_v1beta_freeTier, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + (.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 0 + )), + (.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2_5_Flash, ThinkingConfig( + 
thinkingBudget: 24576 + )), + (.googleAI_v1beta_freeTier_bypassProxy, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: 24576, includeThoughts: true + )), + ] as [(InstanceConfig, String, ThinkingConfig)] + ) func generateContentThinking(_ config: InstanceConfig, modelName: String, - thinkingBudget: Int) async throws { + thinkingConfig: ThinkingConfig) async throws { let model = FirebaseAI.componentInstance(config).generativeModel( modelName: modelName, generationConfig: GenerationConfig( temperature: 0.0, topP: 0.0, topK: 1, - thinkingConfig: ThinkingConfig(thinkingBudget: thinkingBudget) + thinkingConfig: thinkingConfig ), safetySettings: safetySettings ) + let chat = model.startChat() let prompt = "Where is Google headquarters located? Answer with the city name only." - let response = try await model.generateContent(prompt) + let response = try await chat.sendMessage(prompt) let text = try #require(response.text).trimmingCharacters(in: .whitespacesAndNewlines) #expect(text == "Mountain View") + let candidate = try #require(response.candidates.first) + let thoughtParts = candidate.content.parts.compactMap { $0.isThought ? 
$0 : nil } + #expect(thoughtParts.isEmpty != thinkingConfig.includeThoughts) + let usageMetadata = try #require(response.usageMetadata) #expect(usageMetadata.promptTokenCount.isEqual(to: 13, accuracy: tokenCountAccuracy)) #expect(usageMetadata.promptTokensDetails.count == 1) let promptTokensDetails = try #require(usageMetadata.promptTokensDetails.first) #expect(promptTokensDetails.modality == .text) #expect(promptTokensDetails.tokenCount == usageMetadata.promptTokenCount) - if thinkingBudget == 0 { - #expect(usageMetadata.thoughtsTokenCount == 0) - } else { + if let thinkingBudget = thinkingConfig.thinkingBudget, thinkingBudget > 0 { + #expect(usageMetadata.thoughtsTokenCount > 0) #expect(usageMetadata.thoughtsTokenCount <= thinkingBudget) + } else { + #expect(usageMetadata.thoughtsTokenCount == 0) } #expect(usageMetadata.candidatesTokenCount.isEqual(to: 3, accuracy: tokenCountAccuracy)) // The `candidatesTokensDetails` field is erroneously omitted when using the Google AI (Gemini @@ -195,6 +231,94 @@ struct GenerateContentIntegrationTests { )) } + @Test( + arguments: [ + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: -1)), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: -1, includeThoughts: true + )), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: -1)), + (.vertexAI_v1beta_global, ModelNames.gemini2_5_Pro, ThinkingConfig( + thinkingBudget: -1, includeThoughts: true + )), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig(thinkingBudget: -1)), + (.googleAI_v1beta, ModelNames.gemini2_5_Flash, ThinkingConfig( + thinkingBudget: -1, includeThoughts: true + )), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig(thinkingBudget: -1)), + (.googleAI_v1beta, ModelNames.gemini2_5_Pro, ThinkingConfig( + thinkingBudget: -1, includeThoughts: true + )), + ] as [(InstanceConfig, String, ThinkingConfig)] + ) + func 
generateContentThinkingFunctionCalling(_ config: InstanceConfig, modelName: String, + thinkingConfig: ThinkingConfig) async throws { + let getTemperatureDeclaration = FunctionDeclaration( + name: "getTemperature", + description: "Returns the current temperature in Celsius for the specified location", + parameters: [ + "city": .string(), + "region": .string(description: "The province or state"), + "country": .string(), + ] + ) + let model = FirebaseAI.componentInstance(config).generativeModel( + modelName: modelName, + generationConfig: GenerationConfig( + temperature: 0.0, + topP: 0.0, + topK: 1, + thinkingConfig: thinkingConfig + ), + safetySettings: safetySettings, + tools: [.functionDeclarations([getTemperatureDeclaration])], + systemInstruction: ModelContent(parts: """ + You are a weather bot that specializes in reporting outdoor temperatures in Celsius. + + Always use the `getTemperature` function to determine the current temperature in a location. + + Always respond in the format: + - Location: City, Province/State, Country + - Temperature: #C + """) + ) + let chat = model.startChat() + let prompt = "What is the current temperature in Waterloo, Ontario, Canada?" 
+ + let response = try await chat.sendMessage(prompt) + + #expect(response.functionCalls.count == 1) + let temperatureFunctionCall = try #require(response.functionCalls.first) + try #require(temperatureFunctionCall.name == getTemperatureDeclaration.name) + #expect(temperatureFunctionCall.args == [ + "city": .string("Waterloo"), + "region": .string("Ontario"), + "country": .string("Canada"), + ]) + #expect(temperatureFunctionCall.isThought == false) + if let _ = thinkingConfig.includeThoughts, case .googleAI = config.apiConfig.service { + let thoughtSignature = try #require(temperatureFunctionCall.thoughtSignature) + #expect(!thoughtSignature.isEmpty) + } else { + #expect(temperatureFunctionCall.thoughtSignature == nil) + } + + let temperatureFunctionResponse = FunctionResponsePart( + name: temperatureFunctionCall.name, + response: [ + "temperature": .number(25), + "units": .string("Celsius"), + ] + ) + + let response2 = try await chat.sendMessage(temperatureFunctionResponse) + + #expect(response2.functionCalls.isEmpty) + let finalText = try #require(response2.text).trimmingCharacters(in: .whitespacesAndNewlines) + #expect(finalText.contains("Waterloo")) + #expect(finalText.contains("25")) + } + @Test(arguments: [ InstanceConfig.vertexAI_v1beta, InstanceConfig.vertexAI_v1beta_global, diff --git a/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift b/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift index 103943e6f92..00e0d398855 100644 --- a/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift +++ b/FirebaseAI/Tests/Unit/GenerativeModelGoogleAITests.swift @@ -262,6 +262,52 @@ final class GenerativeModelGoogleAITests: XCTestCase { ) } + func testGenerateContent_success_thinking_thoughtSummary() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-thinking-reply-thought-summary", + withExtension: "json", + subdirectory: googleAISubdirectory + ) + + let response = try await 
model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 2) + let thoughtPart = try XCTUnwrap(candidate.content.parts.first as? TextPart) + XCTAssertTrue(thoughtPart.isThought) + XCTAssertTrue(thoughtPart.text.hasPrefix("**Thinking About Google's Headquarters**")) + XCTAssertEqual(thoughtPart.text, response.thoughtSummary) + let textPart = try XCTUnwrap(candidate.content.parts.last as? TextPart) + XCTAssertFalse(textPart.isThought) + XCTAssertEqual(textPart.text, "Mountain View") + XCTAssertEqual(textPart.text, response.text) + } + + func testGenerateContent_success_thinking_functionCall_thoughtSummaryAndSignature() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-thinking-function-call-thought-summary-signature", + withExtension: "json", + subdirectory: googleAISubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.finishReason, .stop) + XCTAssertEqual(candidate.content.parts.count, 2) + let thoughtPart = try XCTUnwrap(candidate.content.parts.first as? TextPart) + XCTAssertTrue(thoughtPart.isThought) + XCTAssertTrue(thoughtPart.text.hasPrefix("**Thinking Through the New Year's Eve Calculation**")) + let functionCallPart = try XCTUnwrap(candidate.content.parts.last as? 
FunctionCallPart) + XCTAssertFalse(functionCallPart.isThought) + XCTAssertEqual(functionCallPart.name, "now") + XCTAssertTrue(functionCallPart.args.isEmpty) + let thoughtSignature = try XCTUnwrap(functionCallPart.thoughtSignature) + XCTAssertTrue(thoughtSignature.hasPrefix("CtQOAVSoXO74PmYr9AFu")) + } + func testGenerateContent_failure_invalidAPIKey() async throws { let expectedStatusCode = 400 MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( @@ -397,6 +443,72 @@ final class GenerativeModelGoogleAITests: XCTestCase { XCTAssertNil(citation.publicationDate) } + func testGenerateContentStream_successWithThoughtSummary() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-thinking-reply-thought-summary", + withExtension: "txt", + subdirectory: googleAISubdirectory + ) + + var thoughtSummary = "" + var text = "" + let stream = try model.generateContentStream("Hi") + for try await response in stream { + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 1) + let textPart = try XCTUnwrap(candidate.content.parts.first as? 
TextPart) + if textPart.isThought { + let newThought = try XCTUnwrap(response.thoughtSummary) + XCTAssertEqual(textPart.text, newThought) + thoughtSummary.append(newThought) + } else { + let newText = try XCTUnwrap(response.text) + XCTAssertEqual(textPart.text, newText) + text.append(newText) + } + } + + XCTAssertTrue(thoughtSummary.hasPrefix("**Exploring Sky Color**")) + XCTAssertTrue(text.hasPrefix("The sky is blue because")) + } + + func testGenerateContentStream_success_thinking_functionCall_thoughtSummary_signature() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-thinking-function-call-thought-summary-signature", + withExtension: "txt", + subdirectory: googleAISubdirectory + ) + + var thoughtSummary = "" + var functionCalls: [FunctionCallPart] = [] + let stream = try model.generateContentStream("Hi") + for try await response in stream { + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 1) + let part = try XCTUnwrap(candidate.content.parts.first) + if part.isThought { + let textPart = try XCTUnwrap(part as? TextPart) + let newThought = try XCTUnwrap(response.thoughtSummary) + XCTAssertEqual(textPart.text, newThought) + thoughtSummary.append(newThought) + } else { + let functionCallPart = try XCTUnwrap(part as? 
FunctionCallPart) + XCTAssertEqual(response.functionCalls.count, 1) + let newFunctionCall = try XCTUnwrap(response.functionCalls.first) + XCTAssertEqual(functionCallPart, newFunctionCall) + functionCalls.append(newFunctionCall) + } + } + + XCTAssertTrue(thoughtSummary.hasPrefix("**Calculating the Days**")) + XCTAssertEqual(functionCalls.count, 1) + let functionCall = try XCTUnwrap(functionCalls.first) + XCTAssertEqual(functionCall.name, "now") + XCTAssertTrue(functionCall.args.isEmpty) + let thoughtSignature = try XCTUnwrap(functionCall.thoughtSignature) + XCTAssertTrue(thoughtSignature.hasPrefix("CiIBVKhc7vB+vaaq6rA")) + } + func testGenerateContentStream_failureInvalidAPIKey() async throws { MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( forResource: "unary-failure-api-key", diff --git a/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift b/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift index 6557735ccc4..7c23726f152 100644 --- a/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift +++ b/FirebaseAI/Tests/Unit/GenerativeModelVertexAITests.swift @@ -434,6 +434,29 @@ final class GenerativeModelVertexAITests: XCTestCase { XCTAssertEqual(text, "The sum of [1, 2, 3] is") } + func testGenerateContent_success_thinking_thoughtSummary() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "unary-success-thinking-reply-thought-summary", + withExtension: "json", + subdirectory: vertexSubdirectory + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.finishReason, .stop) + XCTAssertEqual(candidate.content.parts.count, 2) + let thoughtPart = try XCTUnwrap(candidate.content.parts.first as? 
TextPart) + XCTAssertTrue(thoughtPart.isThought) + XCTAssertTrue(thoughtPart.text.hasPrefix("Right, someone needs the city where Google")) + XCTAssertEqual(response.thoughtSummary, thoughtPart.text) + let textPart = try XCTUnwrap(candidate.content.parts.last as? TextPart) + XCTAssertFalse(textPart.isThought) + XCTAssertEqual(textPart.text, "Mountain View") + XCTAssertEqual(response.text, textPart.text) + } + func testGenerateContent_success_image_invalidSafetyRatingsIgnored() async throws { MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( forResource: "unary-success-image-invalid-safety-ratings", @@ -1330,6 +1353,33 @@ final class GenerativeModelVertexAITests: XCTestCase { XCTAssertFalse(citations.contains { $0.license?.isEmpty ?? false }) } + func testGenerateContentStream_successWithThinking_thoughtSummary() async throws { + MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( + forResource: "streaming-success-thinking-reply-thought-summary", + withExtension: "txt", + subdirectory: vertexSubdirectory + ) + + var thoughtSummary = "" + var text = "" + let stream = try model.generateContentStream("Hi") + for try await response in stream { + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 1) + let part = try XCTUnwrap(candidate.content.parts.first) + let textPart = try XCTUnwrap(part as? 
TextPart) + if textPart.isThought { + let newThought = try XCTUnwrap(response.thoughtSummary) + thoughtSummary.append(newThought) + } else { + text.append(textPart.text) + } + } + + XCTAssertTrue(thoughtSummary.hasPrefix("**Understanding the Core Question**")) + XCTAssertTrue(text.hasPrefix("The sky is blue due to a phenomenon")) + } + func testGenerateContentStream_successWithInvalidSafetyRatingsIgnored() async throws { MockURLProtocol.requestHandler = try GenerativeModelTestUtil.httpRequestHandler( forResource: "streaming-success-image-invalid-safety-ratings", diff --git a/FirebaseAI/Tests/Unit/Types/InternalPartTests.swift b/FirebaseAI/Tests/Unit/Types/InternalPartTests.swift new file mode 100644 index 00000000000..2cd5c5fee2a --- /dev/null +++ b/FirebaseAI/Tests/Unit/Types/InternalPartTests.swift @@ -0,0 +1,286 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +@testable import FirebaseAI +import XCTest + +@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *) +final class InternalPartTests: XCTestCase { + let decoder = JSONDecoder() + + func testDecodeTextPartWithThought() throws { + let json = """ + { + "text": "This is a thought.", + "thought": true + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertEqual(part.isThought, true) + guard case let .text(text) = part.data else { + XCTFail("Decoded part is not a text part.") + return + } + XCTAssertEqual(text, "This is a thought.") + } + + func testDecodeTextPartWithoutThought() throws { + let json = """ + { + "text": "This is not a thought." + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertNil(part.isThought) + guard case let .text(text) = part.data else { + XCTFail("Decoded part is not a text part.") + return + } + XCTAssertEqual(text, "This is not a thought.") + } + + func testDecodeInlineDataPartWithThought() throws { + let imageBase64 = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+P+/HgAFhAJ/wlseKgAAAABJRU5ErkJggg==" + let mimeType = "image/png" + let json = """ + { + "inlineData": { + "mimeType": "\(mimeType)", + "data": "\(imageBase64)" + }, + "thought": true + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertEqual(part.isThought, true) + guard case let .inlineData(inlineData) = part.data else { + XCTFail("Decoded part is not an inlineData part.") + return + } + XCTAssertEqual(inlineData.mimeType, mimeType) + XCTAssertEqual(inlineData.data, Data(base64Encoded: imageBase64)) + } + + func testDecodeInlineDataPartWithoutThought() throws { + let imageBase64 = "aGVsbG8=" + let mimeType = "image/png" + let json = """ + { + "inlineData": { + 
"mimeType": "\(mimeType)", + "data": "\(imageBase64)" + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertNil(part.isThought) + guard case let .inlineData(inlineData) = part.data else { + XCTFail("Decoded part is not an inlineData part.") + return + } + XCTAssertEqual(inlineData.mimeType, mimeType) + XCTAssertEqual(inlineData.data, Data(base64Encoded: imageBase64)) + } + + func testDecodeFileDataPartWithThought() throws { + let uri = "file:///path/to/file.mp3" + let mimeType = "audio/mpeg" + let json = """ + { + "fileData": { + "fileUri": "\(uri)", + "mimeType": "\(mimeType)" + }, + "thought": true + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertEqual(part.isThought, true) + guard case let .fileData(fileData) = part.data else { + XCTFail("Decoded part is not a fileData part.") + return + } + XCTAssertEqual(fileData.fileURI, uri) + XCTAssertEqual(fileData.mimeType, mimeType) + } + + func testDecodeFileDataPartWithoutThought() throws { + let uri = "file:///path/to/file.mp3" + let mimeType = "audio/mpeg" + let json = """ + { + "fileData": { + "fileUri": "\(uri)", + "mimeType": "\(mimeType)" + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertNil(part.isThought) + guard case let .fileData(fileData) = part.data else { + XCTFail("Decoded part is not a fileData part.") + return + } + XCTAssertEqual(fileData.fileURI, uri) + XCTAssertEqual(fileData.mimeType, mimeType) + } + + func testDecodeFunctionCallPartWithThoughtSignature() throws { + let functionName = "someFunction" + let expectedThoughtSignature = "some_signature" + let json = """ + { + "functionCall": { + "name": "\(functionName)", + "args": { + "arg1": "value1" + } + }, + "thoughtSignature": 
"\(expectedThoughtSignature)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + let thoughtSignature = try XCTUnwrap(part.thoughtSignature) + XCTAssertEqual(thoughtSignature, expectedThoughtSignature) + XCTAssertNil(part.isThought) + guard case let .functionCall(functionCall) = part.data else { + XCTFail("Decoded part is not a functionCall part.") + return + } + XCTAssertEqual(functionCall.name, functionName) + XCTAssertEqual(functionCall.args, ["arg1": .string("value1")]) + } + + func testDecodeFunctionCallPartWithoutThoughtSignature() throws { + let functionName = "someFunction" + let json = """ + { + "functionCall": { + "name": "\(functionName)", + "args": { + "arg1": "value1" + } + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertNil(part.isThought) + XCTAssertNil(part.thoughtSignature) + guard case let .functionCall(functionCall) = part.data else { + XCTFail("Decoded part is not a functionCall part.") + return + } + XCTAssertEqual(functionCall.name, functionName) + XCTAssertEqual(functionCall.args, ["arg1": .string("value1")]) + } + + func testDecodeFunctionCallPartWithoutArgs() throws { + let functionName = "someFunction" + let json = """ + { + "functionCall": { + "name": "\(functionName)" + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertNil(part.isThought) + XCTAssertNil(part.thoughtSignature) + guard case let .functionCall(functionCall) = part.data else { + XCTFail("Decoded part is not a functionCall part.") + return + } + XCTAssertEqual(functionCall.name, functionName) + XCTAssertEqual(functionCall.args, JSONObject()) + } + + func testDecodeFunctionResponsePartWithThought() throws { + let functionName = "someFunction" + let json = """ + { + "functionResponse": { + 
"name": "\(functionName)", + "response": { + "output": "someValue" + } + }, + "thought": true + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertEqual(part.isThought, true) + guard case let .functionResponse(functionResponse) = part.data else { + XCTFail("Decoded part is not a functionResponse part.") + return + } + XCTAssertEqual(functionResponse.name, functionName) + XCTAssertEqual(functionResponse.response, ["output": .string("someValue")]) + } + + func testDecodeFunctionResponsePartWithoutThought() throws { + let functionName = "someFunction" + let json = """ + { + "functionResponse": { + "name": "\(functionName)", + "response": { + "output": "someValue" + } + } + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let part = try decoder.decode(InternalPart.self, from: jsonData) + + XCTAssertNil(part.isThought) + guard case let .functionResponse(functionResponse) = part.data else { + XCTFail("Decoded part is not a functionResponse part.") + return + } + XCTAssertEqual(functionResponse.name, functionName) + XCTAssertEqual(functionResponse.response, ["output": .string("someValue")]) + } +}