型レベル多項式

dottyの柔軟な型を利用して型レベルの多項式とその一般の足し算を定義する。

多項式の定式化

まずはenumを用いて多項式の要素を型に落とし込む。
必要な要素は項とその足し算だ。

enum Polynomial:
  case X[Coef <: Int, N <: Int]()
  case Plus[XX <: X[?, ?], YY <: Polynomial]()

Xが項、Plusがその足し算を表している。

Xの型変数Coefは係数、Nは右の肩に乗数を表している。
つまりX[2, 3]は\(2x^3\)のことである。
infix記法を使って2 X 3のように書けばより普通の書き方の様に見える。
定数項は2 X 0の様にして表せる。

Plusについては方の単純化のために左側のXXは項に固定してある。
そのため注意してほしいのはPlusは項を先頭に追加する処理を表現しており、一般の多項式同士の足し算を表現したものではない。
つまり

$$
\begin{aligned}
x^2 + (x + 1)
\end{aligned}
$$

は表せるが、

$$
\begin{aligned}
(x^2 + x) + (x + 1)
\end{aligned}
$$

は表現できない。

これらを利用するれば以下のように多項式は表現できる。

3 X 0 // 3
2 X 1 // 2x
(2 X 1) Plus (3 X 0) // 2x + 3
(1 X 2) Plus ((2 X 1) Plus (3 X 0)) // x^2 + 2x + 3

ヘルパー型

後で使いやすいようにXに対してヘルパー型を定義する。

type N[XX <: X[?, ?]] = XX match
  case ? X n => n

type C[XX <: X[?, ?]] = XX match
  case c X ? => c

それぞれXの型パラメータを抜き出す便利型である。
こんな便利型も昔はむちゃくちゃ頑張らないと書けなかったしいい時代になった。

summon[N[1 X 2] =:= 2]
summon[C[1 X 2] =:= 1]

一般の多項式同士の足し算

Plusでは一般の足し算が表現できないことは上で述べた。
ここで一般の足し算を定義したい。
更に項は乗数のオーダーでソートされてほしい。
そのような型演算は以下のようにして定義できる。

import scala.compiletime.ops.int
import scala.compiletime.ops.any.*
import Polynomial.*


type +[A <: Polynomial, B <: Polynomial] <: Polynomial = A match
  case c X n =>
    B match
      case cc X nn =>
        n == nn match
          case true => int.+[c, cc] X n
          case false =>
            int.>[n, nn] match
              case true  => A Plus B
              case false => B Plus A
      case xx Plus yy =>
        int.>[n, N[xx]] match
          case true => A Plus B
          case false =>
            n == N[xx] match
              case true  => (int.+[c, C[xx]] X n) Plus yy
              case false => xx Plus (A + yy)
  case xx Plus yy =>
    B match
      case cc X nn =>
        N[xx] == nn match
          case true => (int.+[C[xx], cc] X nn) Plus yy
          case false =>
            int.>[N[xx], nn] match
              case true  => xx Plus (B + yy)
              case false => B Plus A
      case xxx Plus yyy =>
        N[xx] == N[xxx] match
          case true => (int.+[C[xx], C[xxx]] X N[xx]) Plus (yy + yyy)
          case false =>
            int.>[N[xx], N[xxx]] match
              case true  => xx Plus (yy + B)
              case false => xxx Plus (yyy + A)

Match Typeが暴れている。
これを使えば以下のように一般の形での多項式の足し算ができる。

summon[(3 X 1) + (2 X 0) =:= ((3 X 1) Plus (2 X 0))]
summon[(3 X 2) + (4 X 1) + (2 X 0) =:= ((3 X 2) Plus ((4 X 1) Plus (2 X 0)))]

summon[(3 X 1) + (4 X 1) =:= (7 X 1)] // 3x + 4x = 7x
summon[(3 X 1) + (4 X 2) =:= (4 X 2) + (3 X 1)] // 3x + 4x^2 = 4x^2 + 3x
summon[(3 X 2) + ((4 X 1) + (2 X 0)) =:= (3 X 2) + (4 X 1) + (2 X 0)] // 3x^2 + (4x + 2) = 3x^2 + 4x + 2
summon[(3 X 1) + ((4 X 1) + (2 X 0)) =:= (7 X 1) + (2 X 0)] // 3x + (4x + 2) = 7x + 2
summon[(3 X 0) + ((4 X 1) + (2 X 0)) =:= (4 X 1) + (5 X 0)] // 3 + (4x + 2) = 4x + 5
summon[((3 X 1) + (2 X 0)) + (4 X 1) =:= (7 X 1) + (2 X 0)] // (3x + 2) + 4x = 7x + 2
summon[((3 X 1) + (2 X 0)) + (4 X 2) =:= (4 X 2) + (3 X 1) + (2 X 0)] // (3x + 2) + 4x^2 = 4x^2 + 3x + 2
summon[(3 X 1) + ((4 X 2) + (2 X 0)) =:= ((4 X 2) + (3 X 1) + (2 X 0))] // 3x + (4x^2 + 2) = 4x^2 + 3x + 2
summon[
  ((5 X 3) + (3 X 1)) + ((4 X 2) + (2 X 0)) =:= (5 X 3) + (4 X 2) + (3 X 1) + (2 X 0)
] // (5x^3 + 3x) + (4x^2 +2) = 5x^3 + 4x^2 + 3x + 2

多分あってる。

終わり

dottyの登場によって型レベル演算はscalaコンパイラを熟知したhackyなものではなく、どれだけ型レベルのifを書けるかの勝負になっている感がある。
最後に全コード。

import scala.compiletime.ops.int
import scala.compiletime.ops.any.*
import Polynomial.*

enum Polynomial:
  case X[Coef <: Int, N <: Int]()
  case Plus[XX <: X[?, ?], YY <: Polynomial]()

type N[XX <: X[?, ?]] = XX match
  case ? X n => n

type C[XX <: X[?, ?]] = XX match
  case c X ? => c

type +[A <: Polynomial, B <: Polynomial] <: Polynomial = A match
  case c X n =>
    B match
      case cc X nn =>
        n == nn match
          case true => int.+[c, cc] X n
          case false =>
            int.>[n, nn] match
              case true  => A Plus B
              case false => B Plus A
      case xx Plus yy =>
        int.>[n, N[xx]] match
          case true => A Plus B
          case false =>
            n == N[xx] match
              case true  => (int.+[c, C[xx]] X n) Plus yy
              case false => xx Plus (A + yy)
  case xx Plus yy =>
    B match
      case cc X nn =>
        N[xx] == nn match
          case true => (int.+[C[xx], cc] X nn) Plus yy
          case false =>
            int.>[N[xx], nn] match
              case true  => xx Plus (B + yy)
              case false => B Plus A
      case xxx Plus yyy =>
        N[xx] == N[xxx] match
          case true => (int.+[C[xx], C[xxx]] X N[xx]) Plus (yy + yyy)
          case false =>
            int.>[N[xx], N[xxx]] match
              case true  => xx Plus (yy + B)
              case false => xxx Plus (yyy + A)

object Polynomial:
  summon[N[1 X 2] =:= 2]
  summon[C[1 X 2] =:= 1]

  summon[(3 X 1) + (4 X 1) =:= (7 X 1)]
  summon[(3 X 1) + (2 X 0) =:= ((3 X 1) Plus (2 X 0))]
  summon[(3 X 1) + (4 X 2) =:= (4 X 2) + (3 X 1)]
  summon[(3 X 2) + (4 X 1) + (2 X 0) =:= ((3 X 2) Plus ((4 X 1) Plus (2 X 0)))]
  summon[(3 X 2) + ((4 X 1) + (2 X 0)) =:= (3 X 2) + (4 X 1) + (2 X 0)]
  summon[(3 X 1) + ((4 X 1) + (2 X 0)) =:= (7 X 1) + (2 X 0)]
  summon[(3 X 0) + ((4 X 1) + (2 X 0)) =:= (4 X 1) + (5 X 0)]
  summon[((3 X 1) + (2 X 0)) + (4 X 1) =:= (7 X 1) + (2 X 0)]
  summon[((3 X 1) + (2 X 0)) + (4 X 2) =:= (4 X 2) + (3 X 1) + (2 X 0)]
  summon[(3 X 1) + ((4 X 2) + (2 X 0)) =:= ((4 X 2) + (3 X 1) + (2 X 0))]
  summon[
    ((5 X 3) + (3 X 1)) + ((4 X 2) + (2 X 0)) =:= (5 X 3) + (4 X 2) + (3 X 1) + (2 X 0)
  ]
リテラル型(literal type) [初級-中級向け]Scala基本APIを完全に理解するシリーズ② -Either編- Scala summoner pattern
View Comments
There are currently no comments.