|
| 1 | +""" |
| 2 | + LCMType |
| 3 | +
|
| 4 | +Supertype of concrete Julia `struct`s that represent LCM message types. |
| 5 | +
|
| 6 | +Subtypes must be `mutable struct`s and may use the following field types: |
| 7 | +
|
| 8 | +* `Bool` |
| 9 | +* numeric types: `Int8`, `Int16`, `Int32`, `Int64`, `Float32`, `Float64` |
| 10 | +* bytes (encoded in the same was as `Int8`): `UInt8`; |
| 11 | +* `String`; |
| 12 | +* another `LCMType`; |
| 13 | +* `Vector` or a subtype of `StaticVector`, for which the element type must also be |
| 14 | +one of the previously specified types or another `Vector` or `StaticVector`. |
| 15 | +
|
| 16 | +The following methods must be defined for a concrete subtype of `LCMType` (say `MyType`): |
| 17 | +
|
| 18 | +* `check_valid(x::MyType)` |
| 19 | +* `size_fields(::Type{MyType})` |
| 20 | +* `fingerprint(::Type{MyType})` |
| 21 | +* `Base.resize!(x::MyType)` |
| 22 | +
|
| 23 | +Any size fields must come **before** the `Vector` fields to which they correspond. |
| 24 | +
|
| 25 | +Note that ideally, all of these methods would be generated from the LCM message type |
| 26 | +definition, but that is currently not the case. |
| 27 | +""" |
| 28 | +abstract type LCMType end |
| 29 | + |
| 30 | +""" |
| 31 | +size_fields(x::Type{T}) where T<:LCMType |
| 32 | +
|
| 33 | +Returns a tuple of `Symbol`s corresponding to the fields of `T` that represent vector dimensions. |
| 34 | +""" |
| 35 | +size_fields(x::Type{T}) where {T<:LCMType} = error("size_fields method not defined for LCMType $T.") |
| 36 | + |
| 37 | +""" |
| 38 | +check_valid(x::LCMType) |
| 39 | +
|
| 40 | +Check that `x` is a valid LCM type. For example, check that array lengths are correct. |
| 41 | +""" |
| 42 | +check_valid(x::LCMType) = error("check_valid method not defined for LCMType $(typeof(x)).") |
| 43 | + |
| 44 | +""" |
| 45 | +fingerprint(::Type{T}) where T<:LCMType |
| 46 | +
|
| 47 | +Return the fingerprint of LCM type `T` as an `SVector{8, UInt8}`. |
| 48 | +""" |
| 49 | +fingerprint(::Type{T}) where {T<:LCMType} = error("fingerprint method not defined for LCMType $T.") |
| 50 | + |
| 51 | +# Types that are encoded in network byte order, as dictated by the LCM type specification. |
| 52 | +const NETWORK_BYTE_ORDER_TYPES = Union{Int8, Int16, Int32, Int64, Float32, Float64, UInt8} |
| 53 | + |
| 54 | +# Default values for all of the possible field types of an LCM type: |
| 55 | +default_value(::Type{Bool}) = false |
| 56 | +default_value(::Type{T}) where {T<:NETWORK_BYTE_ORDER_TYPES} = zero(T) |
| 57 | +default_value(::Type{String}) = "" |
| 58 | +default_value(::Type{T}) where {T<:Vector} = T() |
| 59 | +default_value(::Type{SV}) where {N, T, SV<:StaticVector{N, T}} = SV(ntuple(i -> default_value(T), Val(N))) |
| 60 | +default_value(::Type{T}) where {T<:LCMType} = T() |
| 61 | + |
| 62 | +# Generated default constructor for LCMType subtypes |
| 63 | +@generated function (::Type{T})() where T<:LCMType |
| 64 | + constructor_arg_exprs = [:(default_value(fieldtype(T, $i))) for i = 1 : fieldcount(T)] |
| 65 | + :(T($(constructor_arg_exprs...))) |
| 66 | +end |
| 67 | + |
| 68 | +# Fingerprint check |
| 69 | +struct FingerprintException <: Exception |
| 70 | + T::Type |
| 71 | +end |
| 72 | + |
| 73 | +@noinline function Base.showerror(io::IO, e::FingerprintException) |
| 74 | + print(io, "LCM message fingerprint did not match type ", e.T, ". ") |
| 75 | + print(io, "This means that you are trying to decode the wrong message type, or a different version of the message type.") |
| 76 | +end |
| 77 | + |
| 78 | +function check_fingerprint(io::IO, ::Type{T}) where T<:LCMType |
| 79 | + decodefield(io, SVector{8, UInt8}) == fingerprint(T) || throw(FingerprintException(T)) |
| 80 | +end |
| 81 | + |
| 82 | +# Decoding |
| 83 | +function decode!(x::LCMType, io::IO) |
| 84 | + check_fingerprint(io, typeof(x)) |
| 85 | + decodefield!(x, io) |
| 86 | +end |
| 87 | + |
| 88 | +""" |
| 89 | + decode_in_place(T) |
| 90 | +
|
| 91 | +Specify whether type `T` should be decoded in place, i.e. whether to use a |
| 92 | +`decodefield!` method instead of a `decodefield` method. |
| 93 | +""" |
| 94 | +function decode_in_place end |
| 95 | + |
| 96 | +Base.@pure decode_in_place(::Type{<:LCMType}) = true |
| 97 | +@generated function decodefield!(x::T, io::IO) where T<:LCMType |
| 98 | + field_assignments = Vector{Expr}(fieldcount(x)) |
| 99 | + for (i, fieldname) in enumerate(fieldnames(x)) |
| 100 | + F = fieldtype(x, fieldname) |
| 101 | + field_assignments[i] = quote |
| 102 | + if decode_in_place($F) |
| 103 | + decodefield!(x.$fieldname, io) |
| 104 | + else |
| 105 | + x.$fieldname = decodefield(io, $F) |
| 106 | + end |
| 107 | + # $(QuoteNode(fieldname)) ∈ size_fields(T) && resize!(x) # allocates! |
| 108 | + any($(QuoteNode(fieldname)) .== size_fields(T)) && resize!(x) |
| 109 | + end |
| 110 | + end |
| 111 | + return quote |
| 112 | + $(field_assignments...) |
| 113 | + return x |
| 114 | + end |
| 115 | +end |
| 116 | + |
| 117 | +Base.@pure decode_in_place(::Type{Bool}) = false |
| 118 | +decodefield(io::IO, ::Type{Bool}) = read(io, UInt8) == 0x01 |
| 119 | + |
| 120 | +Base.@pure decode_in_place(::Type{<:NETWORK_BYTE_ORDER_TYPES}) = false |
| 121 | +decodefield(io::IO, ::Type{T}) where {T<:NETWORK_BYTE_ORDER_TYPES} = ntoh(read(io, T)) |
| 122 | + |
| 123 | +Base.@pure decode_in_place(::Type{String}) = false |
| 124 | +function decodefield(io::IO, ::Type{String}) |
| 125 | + len = ntoh(read(io, UInt32)) |
| 126 | + ret = String(read(io, len - 1)) |
| 127 | + read(io, UInt8) # strip off null |
| 128 | + ret |
| 129 | +end |
| 130 | + |
| 131 | +Base.@pure decode_in_place(::Type{<:Vector}) = true |
| 132 | +function decodefield!(x::Vector{T}, io::IO) where T |
| 133 | + @inbounds for i in eachindex(x) |
| 134 | + if decode_in_place(T) |
| 135 | + isassigned(x, i) || (x[i] = default_value(T)) |
| 136 | + decodefield!(x[i], io) |
| 137 | + else |
| 138 | + x[i] = decodefield(io, T) |
| 139 | + end |
| 140 | + end |
| 141 | + x |
| 142 | +end |
| 143 | + |
| 144 | +Base.@pure decode_in_place(::Type{SV}) where {SV<:StaticVector} = decode_in_place(eltype(SV)) |
| 145 | +function decodefield!(x::StaticVector, io::IO) |
| 146 | + decode_in_place(eltype(x)) || error() |
| 147 | + @inbounds for i in eachindex(x) |
| 148 | + decodefield!(x[i], io) |
| 149 | + end |
| 150 | + x |
| 151 | +end |
| 152 | +@generated function decodefield(io::IO, ::Type{SV}) where {N, T, SV<:StaticVector{N, T}} |
| 153 | + constructor_arg_exprs = [:(decodefield(io, T)) for i = 1 : N] |
| 154 | + return quote |
| 155 | + decode_in_place(T) && error() |
| 156 | + SV(tuple($(constructor_arg_exprs...))) |
| 157 | + end |
| 158 | +end |
| 159 | + |
| 160 | + |
| 161 | +# Encoding |
| 162 | +""" |
| 163 | + encode(io::IO, x::LCMType) |
| 164 | +
|
| 165 | +Write an LCM byte representation of `x` to `io`. |
| 166 | +""" |
| 167 | +function encode(io::IO, x::LCMType) |
| 168 | + encodefield(io, fingerprint(typeof(x))) |
| 169 | + encodefield(io, x) |
| 170 | +end |
| 171 | + |
| 172 | +@generated function encodefield(io::IO, x::LCMType) |
| 173 | + encode_exprs = Vector{Expr}(fieldcount(x)) |
| 174 | + for (i, fieldname) in enumerate(fieldnames(x)) |
| 175 | + encode_exprs[i] = :(encodefield(io, x.$fieldname)) |
| 176 | + end |
| 177 | + quote |
| 178 | + check_valid(x) |
| 179 | + $(encode_exprs...) |
| 180 | + io |
| 181 | + end |
| 182 | +end |
| 183 | + |
| 184 | +encodefield(io::IO, x::Bool) = write(io, ifelse(x, 0x01, 0x00)) |
| 185 | + |
| 186 | +encodefield(io::IO, x::NETWORK_BYTE_ORDER_TYPES) = write(io, hton(x)) |
| 187 | + |
| 188 | +function encodefield(io::IO, x::String) |
| 189 | + write(io, hton(UInt32(length(x) + 1))) |
| 190 | + write(io, x) |
| 191 | + write(io, UInt8(0)) |
| 192 | +end |
| 193 | + |
| 194 | +function encodefield(io::IO, A::AbstractVector) |
| 195 | + for x in A |
| 196 | + encodefield(io, x) |
| 197 | + end |
| 198 | +end |
| 199 | + |
| 200 | +# Sugar |
| 201 | +encode(data::Vector{UInt8}, x::LCMType) = encode(BufferedOutputStream(data), x) |
| 202 | +encode(x::LCMType) = (stream = BufferedOutputStream(); encode(stream, x); flush(stream); take!(stream)) |
| 203 | + |
| 204 | +decode!(x::LCMType, data::Vector{UInt8}) = decode!(x, BufferedInputStream(data)) |
| 205 | +decode(data::Vector{UInt8}, ::Type{T}) where {T<:LCMType} = decode!(T(), data) |
0 commit comments