Skip to content

Commit cc02d7a

Browse files
authored
Document main concepts (#306)
This PR adds documentation for the main concepts in the functional RL package.
1 parent efcf71b commit cc02d7a

File tree

12 files changed

+122
-17
lines changed

12 files changed

+122
-17
lines changed

scala-rl-core/src/main/scala/com/scalarl/Agent.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
*
66
* The nodes are:
77
*
8-
* \- State nodes, with edges leading out to each possible action. \- Action nodes, with edges
9-
* leading out to (reward, state) pairs.
8+
* - State nodes, with edges leading out to each possible action.
9+
* - Action nodes, with edges leading out to (reward, state) pairs.
1010
*
1111
* Policies are maps of State => Map[A, Weight]. I don't know that I have a policy that is NOT
1212
* that.
@@ -15,8 +15,8 @@
1515
*
1616
* So to get the value of an ACTION node you need either:
1717
*
18-
* \- To track it directly, with an ActionValueFn, or \- to estimate it with some model of the
19-
* dynamics of the system.
18+
* - To track it directly, with an ActionValueFn, or
19+
* - to estimate it with some model of the dynamics of the system.
2020
*
2121
* TODO - Key questions: \- Can I rethink the interface here? Can StateValueFn instances ONLY be
2222
* calculated for... rings where the weights add up to 1? "Affine combination" is the key idea

scala-rl-core/src/main/scala/com/scalarl/SARS.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,38 @@ package com.scalarl
22

33
import cats.Functor
44

5-
/** Chunk that you get back for playing an episode.
5+
/** Represents a single step in a reinforcement learning episode.
6+
*
7+
* SARS stands for State-Action-Reward-State, capturing the complete transition:
8+
* - The initial state the agent was in
9+
* - The action the agent took
10+
* - The reward received for taking that action
11+
* - The next state the environment transitioned to
612
*/
713
final case class SARS[Obs, A, R, S[_]](
814
state: State[Obs, A, R, S],
915
action: A,
1016
reward: R,
1117
nextState: State[Obs, A, R, S]
1218
) {
19+
20+
/** Maps the observation type of this SARS to a new type.
21+
*
22+
* @param f
23+
* The function to transform the observation from type Obs to type P
24+
* @param S
25+
* Evidence that S has a Functor instance
26+
*/
1327
def mapObservation[P](f: Obs => P)(implicit S: Functor[S]): SARS[P, A, R, S] =
1428
SARS(state.mapObservation(f), action, reward, nextState.mapObservation(f))
1529

30+
/** Maps the reward type of this SARS to a new type.
31+
*
32+
* @param f
33+
* The function to transform the reward from type R to type T
34+
* @param S
35+
* Evidence that S has a Functor instance
36+
*/
1637
def mapReward[T](f: R => T)(implicit S: Functor[S]): SARS[Obs, A, T, S] =
1738
SARS(state.mapReward(f), action, f(reward), nextState.mapReward(f))
1839

scala-rl-core/src/main/scala/com/scalarl/State.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,17 @@ trait State[Obs, A, @specialized(Int, Long, Float, Double) R, M[_]] { self =>
4444
def actions: Set[A] = dynamics.keySet
4545
def act(action: A): M[(R, This)] = dynamics.getOrElse(action, invalidMove)
4646

47-
/** Returns a list of possible actions to take from this state. To specify the terminal state,
48-
* return an empty set.
49-
*/
5047
def isTerminal: Boolean = actions.isEmpty
5148

49+
/** Maps the observation type of this state to a new type.
50+
*
51+
* @param f
52+
* The function to transform the observation from type Obs to type P
53+
* @param M
54+
* Evidence that M has a Functor instance
55+
* @return
56+
* A new State with observations of type P but the same actions and rewards
57+
*/
5258
def mapObservation[P](
5359
f: Obs => P
5460
)(implicit M: Functor[M]): State[P, A, R, M] =

scala-rl-core/src/main/scala/com/scalarl/Time.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
package com.scalarl
22

3+
/** A value class wrapper around Long that allows us to talk about time ticking and evolution in a
4+
* type-safe way.
5+
*
6+
* This class provides methods for incrementing time, comparing time values, and basic arithmetic
7+
* operations, while maintaining type safety through the AnyVal wrapper.
8+
*/
9+
310
case class Time(value: Long) extends AnyVal {
411
def tick: Time = Time(value + 1)
512
def -(r: Time) = value - r.value

scala-rl-core/src/main/scala/com/scalarl/Util.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,30 @@ import scala.language.higherKinds
1212
object Util {
1313
import cats.syntax.functor._
1414

15+
/** Here we provide various "missing" typeclass instances sewing together algebird typeclasses and
16+
* implementing typeclasses for rainier types.
17+
*/
1518
object Instances {
19+
// this lets us sort AveragedValue instances...
1620
implicit val averageValueOrd: Ordering[AveragedValue] =
1721
Ordering.by(_.value)
1822

23+
// shows how to extract the averaged value out from the accumulating data structure
1924
implicit val avToDouble: ToDouble[AveragedValue] =
2025
ToDouble.instance(_.value)
2126

27+
// Module instance, representing a module that can scale AveragedValue by some scalar double.
2228
implicit val avModule: Module[Double, AveragedValue] =
2329
Module.from((r, av) => AveragedValue(av.count, r * av.value))
2430

31+
// easy, just expose this implicitly.
2532
implicit val realRing: Ring[Real] = RealRing
2633

34+
// trivial VectorSpace, showing that the cats.Id monad (and any Ring R) form a vectorspace.
2735
implicit def idVectorSpace[R](implicit R: Ring[R]): VectorSpace[R, Id] =
2836
VectorSpace.from[R, Id](R.times(_, _))
2937

38+
// Ring instance for rainer Reals.
3039
object RealRing extends Ring[Real] {
3140
override def one = Real.one
3241
override def zero = Real.zero
@@ -37,11 +46,33 @@ object Util {
3746
}
3847
}
3948

40-
def confine[A](a: A, min: A, max: A)(implicit ord: Ordering[A]): A =
49+
/** Clamps a value between a minimum and maximum value.
50+
*
51+
* This function ensures that the input value `a` is not less than `min` and not greater than
52+
* `max`, returning the clamped value.
53+
*
54+
* @param a
55+
* The value to clamp.
56+
* @param min
57+
* The minimum value.
58+
*/
59+
def clamp[A](a: A, min: A, max: A)(implicit ord: Ordering[A]): A =
4160
ord.min(ord.max(a, min), max)
4261

62+
/** Creates a Map from a set of keys using a function to generate values.
63+
*
64+
* This function takes a set of keys and a function that maps each key to a value, returning a
65+
* Map with the keys and their corresponding values.
66+
*
67+
* @param keys
68+
*/
4369
def makeMap[K, V](keys: Set[K])(f: K => V): Map[K, V] = makeMapUnsafe(keys)(f)
4470

71+
/** similar to makeMap, but doesn't guarantee that there are not duplicate keys. If keys contains
72+
* duplicates, later keys override earlier keys.
73+
*
74+
* @param keys
75+
*/
4576
def makeMapUnsafe[K, V](keys: TraversableOnce[K])(f: K => V): Map[K, V] =
4677
keys.foldLeft(Map.empty[K, V]) { case (m, k) =>
4778
m.updated(k, f(k))
@@ -53,21 +84,28 @@ object Util {
5384
def updateWith[K, V](m: Map[K, V], k: K)(f: Option[V] => V): Map[K, V] =
5485
m.updated(k, f(m.get(k)))
5586

87+
/** Merges a key and a value into a map using a semigroup to combine values. */
5688
def mergeV[K, V: Semigroup](m: Map[K, V], k: K, delta: V): Map[K, V] =
5789
updateWith(m, k) {
5890
case None => delta
5991
case Some(v) => Semigroup.plus[V](v, delta)
6092
}
6193

94+
/** Finds the keys with the maximum values in a map.
95+
*/
6296
def maxKeys[A, B: Ordering](m: Map[A, B]): Set[A] = allMaxBy(m.keySet)(m(_))
6397

98+
/** Returns the set of keys that map (via `f`) to the maximal B, out of all `as` transformed.
99+
*/
64100
def allMaxBy[A, B: Ordering](as: Set[A])(f: A => B): Set[A] =
65101
if (as.isEmpty) Set.empty
66102
else {
67103
val maxB = f(as.maxBy(f))
68104
as.filter(a => Ordering[B].equiv(maxB, f(a)))
69105
}
70106

107+
/** Iterates a monadic function `f` `n` of times using the starting value `a`.
108+
*/
71109
def iterateM[M[_], A](
72110
n: Int
73111
)(a: A)(f: A => M[A])(implicit M: Monad[M]): M[A] =

scala-rl-core/src/main/scala/com/scalarl/algebra/AffineCombination.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ package algebra
44
import cats.Id
55
import com.twitter.algebird.Ring
66

7-
/** Another attempt at a better thing, here... but I don't know if this solves my problem of needing
8-
* to compose up the stack,
7+
/** This is not currently used! Another attempt at a better thing, here... but I don't know if this
8+
* solves my problem of needing to compose up the stack.
9+
*
10+
* I had a note about this in [[Agent]].
911
*/
1012
trait AffineCombination[M[_], R] {
1113
implicit def ring: Ring[R]

scala-rl-core/src/main/scala/com/scalarl/algebra/Module.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,40 @@ package algebra
33

44
import com.twitter.algebird.{Group, Ring, VectorSpace}
55

6-
/** This class represents a module. For the required properties see:
6+
/** This class represents an abstract-algebraic "module". A module is a generalization of vector
7+
* spaces that allows scalars to come from a ring instead of a field. It consists of:
78
*
8-
* https://en.wikipedia.org/wiki/Module_(mathematics)
9+
* - An abelian group (G, +) representing the elements that can be scaled
10+
* - A ring (R, +, *) representing the scalars
11+
* - A scaling operation R × G → G that satisfies:
12+
* - r(g₁ + g₂) = rg₁ + rg₂ (distributivity over group addition)
13+
* - (r₁ + r₂)g = r₁g + r₂g (distributivity over ring addition)
14+
* - (r₁r₂)g = r₁(r₂g) (compatibility with ring multiplication)
15+
* - 1g = g (identity scalar)
16+
*
17+
* For more details see: https://en.wikipedia.org/wiki/Module_(mathematics)
918
*/
1019
object Module {
20+
// the default module!
1121
type DModule[T] = Module[Double, T]
1222

23+
/** This method is used to get the default module for a given type.
24+
*
25+
* @param M
26+
* The module to get.
27+
* @return
28+
* The default module for the given type.
29+
*/
1330
@inline final def apply[R, G](implicit M: Module[R, G]): Module[R, G] = M
1431

32+
/** supplies an implicit module, given an implicitly-available Ring for some type R.
33+
*/
1534
implicit def ringModule[R: Ring]: Module[R, R] = from(Ring.times(_, _))
1635

36+
/** Given an implicit ring and group, accepts a scaleFn that shows how to perform scalar
37+
* multiplication between elements of the ring and the group and returns a new module over R and
38+
* G.
39+
*/
1740
def from[R, G](
1841
scaleFn: (R, G) => G
1942
)(implicit R: Ring[R], G: Group[G]): Module[R, G] =
@@ -24,6 +47,9 @@ object Module {
2447
if (R.isNonZero(r)) scaleFn(r, g) else G.zero
2548
}
2649

50+
/* Algebird's vector space is generic on the container type C, and implicitly pulls in a group on
51+
C[F]. We are a little more general.
52+
*/
2753
def fromVectorSpace[F, C[_]](implicit
2854
R: Ring[F],
2955
V: VectorSpace[F, C]

scala-rl-core/src/main/scala/com/scalarl/algebra/Weight.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ object Weight {
2222

2323
implicit val timesMonoid: Monoid[Weight] = Monoid.from(One)(_ * _)
2424
implicit val ord: Ordering[Weight] = Ordering.by(_.w)
25+
implicit val toDouble: ToDouble[Weight] = ToDouble.instance(_.w)
2526
}

scala-rl-core/src/main/scala/com/scalarl/package.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@ package com
33
/** Functional reinforcement learning in Scala.
44
*/
55
package object scalarl {
6+
7+
/** Type alias for [[com.scalarl.rainier.Categorical]], which represents a finite discrete
8+
* probability distribution.
9+
*/
610
type Cat[+T] = rainier.Categorical[T]
711
}

scala-rl-core/src/main/scala/com/scalarl/rainier/Categorical.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import com.scalarl.algebra.ToDouble
1010
import scala.annotation.tailrec
1111
import scala.collection.immutable.Queue
1212

13-
/** A finite discrete distribution.
13+
/** Identical to rainier's `Categorical`, except written with `Double` instead of `Real`.
1414
*
1515
* @param pmfSeq
1616
* A map with keys corresponding to the possible outcomes and values corresponding to the

0 commit comments

Comments
 (0)