import itertools
import logging
from typing import List, Optional

import numpy as np
from osgeo import osr, gdal
from scipy.spatial import cKDTree

from .geometry import Domain
from .gmsh import _curve_sample


def _ensure_valid_points(x, pfrom, pto):
    x = np.asarray(x)[:, :2]
    if not pfrom.IsSame(pto):
        logging.info("reprojecting coordinates")
        x = osr.CoordinateTransformation(pfrom, pto).TransformPoints(x)
        x = np.asarray(x)
    return x

Jonathan Lambrechts's avatar
Jonathan Lambrechts committed
19
20
21
22
23
24
25
26

class Distance:
    """Callable evaluating the distance to a set of discretized curves.

    The curves are discretized as sets of points, then the distance to the
    closest point is computed.
    """

    def __init__(self, domain: Domain, sampling: float,
                 tags: Optional[List[str]] = None):
        """
        Args:
            domain: a Domain object containing the set of curves
            sampling: the interval between two consecutive sampling points.
            tags: List of physical tags specifying the curve from the domain.
                If None, all curves are taken into account.
        Raises:
            ValueError: if no curve of the domain matches tags.
        """
        points = [
            _curve_sample(curve, sampling)
            for curve in itertools.chain(domain._curves,
                                         domain._interior_curves)
            if tags is None or curve.tag in tags
        ]
        if not points:
            # np.vstack would fail with an unhelpful message on an empty list
            raise ValueError("no curve of the domain matches the given tags")
        self._tree = cKDTree(np.vstack(points))
        self._projection = domain._projection

    def __call__(self, x: np.ndarray, projection: osr.SpatialReference
                 ) -> np.ndarray:
        """Compute the distance between each point of x and the curves.

        Args:
            x: the points [n,2] (extra columns are ignored)
            projection: the coordinate system of the points, should be
                the same coordinate system as the domain, otherwise no
                conversion is done and an exception is raised.
        Returns:
            The distance expressed in the domain unit. [n]
        Raises:
            ValueError: if projection differs from the domain projection.
        """
        if not projection.IsSame(self._projection):
            raise ValueError("incompatible projection")
        x = np.asarray(x)[:, :2]
        # cKDTree.query returns (distances, indices); keep the distances only
        return self._tree.query(x)[0]


class Raster:
    """Callable to evaluate a raster field loaded from a file."""

    def __init__(self, filename: str):
        """
        Args:
            filename: A geotiff file or any other raster supported by gdal.
        Raises:
            OSError: if gdal cannot open the file.
            ValueError: if the raster is rotated (non-zero rotation terms
                in its geotransform), which is not supported.
        """
        src_ds = gdal.Open(filename)
        if src_ds is None:
            # without gdal.UseExceptions(), Open reports failure by
            # returning None
            raise OSError("cannot open raster file {!r}".format(filename))
        self._geo_matrix = src_ds.GetGeoTransform()
        # asserts are stripped under -O: validate explicitly
        if self._geo_matrix[2] != 0. or self._geo_matrix[4] != 0.:
            raise ValueError("rotated rasters are not supported")
        self._data = src_ds.GetRasterBand(1).ReadAsArray()
        # GetProjection returns a WKT string, but _ensure_valid_points needs
        # an osr.SpatialReference (it calls IsSame/CoordinateTransformation)
        self._projection = osr.SpatialReference()
        self._projection.ImportFromWkt(src_ds.GetProjection())
        if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
            # GDAL >= 3: the geotransform is in (easting, northing) order,
            # so reprojected points must use that order too, whatever the
            # authority-defined axis order of the CRS is.
            self._projection.SetAxisMappingStrategy(
                osr.OAMS_TRADITIONAL_GIS_ORDER)

    def __call__(self, x: np.ndarray, projection: osr.SpatialReference
                 ) -> np.ndarray:
        """Evaluate the field value on each point of x.

        Args:
            x: the points [n,2]
            projection: the coordinate system of the points
        Returns:
            The field value on points x (value of the containing pixel). [n]
        Raises:
            IndexError: if some points lie outside the raster extent.
        """
        x = _ensure_valid_points(x, projection, self._projection)
        gm = self._geo_matrix
        # geotransform (no rotation): xg = gm[0] + col*gm[1],
        #                             yg = gm[3] + row*gm[5]
        col = np.floor((x[:, 0] - gm[0]) / gm[1]).astype(int)
        row = np.floor((x[:, 1] - gm[3]) / gm[5]).astype(int)
        nrow, ncol = self._data.shape
        outside = (col < 0) | (col >= ncol) | (row < 0) | (row >= nrow)
        if np.any(outside):
            # negative indices would otherwise silently wrap around
            raise IndexError("{} point(s) outside the raster extent".format(
                int(np.count_nonzero(outside))))
        # the raster array is indexed [row, column]
        return self._data[row, col]